gaoqiong / composable_kernel · Commits

Commit dd6a8de4, authored Apr 06, 2022 by Jehandad Khan

    Merge branch 'develop' into jd/dev_pkg

    parents 0aa899aa abf4bdb9

Changes: 470 files in the commit; showing 20 changed files with 2522 additions and 220 deletions (+2522 -220).
Changed files (all under include/ck/tensor_operation/gpu/device/):

    device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp            +184  -10
    device_gemm.hpp                                      +29  -30
    device_gemm_bias.hpp                                 +40   -0
    device_gemm_reduce.hpp                               +50   -0
    device_gemm_reduce_xdl_cshuffle.hpp                 +747   -0
    device_gemm_xdl.hpp                                   +5   -5
    device_gemm_xdl_c_shuffle.hpp                         +1   -1
    device_gemm_xdl_c_shuffle_bias_2d.hpp                 +2   -4
    device_gemm_xdl_c_shuffle_bias_activation.hpp         +1   -1
    device_gemm_xdl_c_shuffle_bias_activation_add.hpp     +1   -1
    device_gemm_xdl_cshuffle.hpp                        +691   -0
    device_gemm_xdl_splitk.hpp                            +6   -6
    device_gemm_xdl_splitk_c_shuffle.hpp                  +6   -6
    device_grouped_gemm_xdl.hpp                         +562   -0
    device_pool2d_fwd.hpp                                 +2   -2
    device_pool2d_fwd_nhwc_nhwc.hpp                       +2   -2
    device_reduce.hpp                                    +10   -7
    device_reduce_blockwise.hpp                          +85  -66
    device_reduce_blockwise_second_call.hpp              +52  -42
    device_reduce_common.hpp                             +46  -37
include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp

@@ -44,7 +44,7 @@ template <typename InDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
+          ConvolutionForwardSpecialization ConvForwardSpecialization,
           ck::index_t NumDimSpatial,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
@@ -142,7 +142,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         const index_t ConvStrideW = conv_filter_strides[0];

         if constexpr(ConvForwardSpecialization ==
-                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+                     ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
             const auto in_gemmmraw_gemmk_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
@@ -156,7 +156,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
         else if constexpr(ConvForwardSpecialization ==
-                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+                          ConvolutionForwardSpecialization::Filter1x1Pad0)
         {
             const auto in_n_wi_c_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
@@ -262,7 +262,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         const index_t ConvStrideW = conv_filter_strides[1];

         if constexpr(ConvForwardSpecialization ==
-                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
+                     ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
             const auto in_gemmmraw_gemmk_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
@@ -276,7 +276,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
         else if constexpr(ConvForwardSpecialization ==
-                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
+                          ConvolutionForwardSpecialization::Filter1x1Pad0)
         {
             const auto in_n_hi_wi_c_grid_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
@@ -367,6 +367,155 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         }
     }
+    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
+    static auto GetInputTensorDescriptor(ck::index_t N,
+                                         ck::index_t C,
+                                         ck::index_t gemm_m,
+                                         ck::index_t gemm_k,
+                                         ck::index_t gemm_m_pad,
+                                         const std::vector<ck::index_t>& input_spatial_lengths,
+                                         const std::vector<ck::index_t>& filter_spatial_lengths,
+                                         const std::vector<ck::index_t>& output_spatial_lengths,
+                                         const std::vector<ck::index_t>& conv_filter_strides,
+                                         const std::vector<ck::index_t>& conv_filter_dilations,
+                                         const std::vector<ck::index_t>& input_left_pads,
+                                         const std::vector<ck::index_t>& input_right_pads)
+    {
+        const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
+
+        const index_t Di = input_spatial_lengths[0];
+        const index_t Hi = input_spatial_lengths[1];
+        const index_t Wi = input_spatial_lengths[2];
+
+        const index_t Do = output_spatial_lengths[0];
+        const index_t Ho = output_spatial_lengths[1];
+        const index_t Wo = output_spatial_lengths[2];
+
+        const index_t ConvStrideD = conv_filter_strides[0];
+        const index_t ConvStrideH = conv_filter_strides[1];
+        const index_t ConvStrideW = conv_filter_strides[2];
+
+        if constexpr(ConvForwardSpecialization ==
+                     ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
+        {
+            const auto in_gemmmraw_gemmk_grid_desc =
+                make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
+
+            // in_gemmk0_gemmm_gemmk1_grid_desc
+            return transform_tensor_descriptor(
+                in_gemmmraw_gemmk_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
+                           make_right_pad_transform(gemm_m, gemm_m_pad)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+        else if constexpr(ConvForwardSpecialization ==
+                          ConvolutionForwardSpecialization::Filter1x1Pad0)
+        {
+            const auto in_n_di_hi_wi_c_grid_desc =
+                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
+
+            const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
+                in_n_di_hi_wi_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
+                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
+                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
+                           make_pass_through_transform(C)),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+            const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
+                in_n_do_ho_wo_c_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
+                           make_merge_transform(make_tuple(N, Do, Ho, Wo))),
+                make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            // in_gemmk0_gemmm_gemmk1_grid_desc
+            return transform_tensor_descriptor(
+                in_gemmk0_gemmmraw_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(gemm_k0),
+                           make_right_pad_transform(gemm_m, gemm_m_pad),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+        else
+        {
+            const index_t Z = filter_spatial_lengths[0];
+            const index_t Y = filter_spatial_lengths[1];
+            const index_t X = filter_spatial_lengths[2];
+
+            const index_t ConvDilationD = conv_filter_dilations[0];
+            const index_t ConvDilationH = conv_filter_dilations[1];
+            const index_t ConvDilationW = conv_filter_dilations[2];
+
+            const index_t InLeftPadD = input_left_pads[0];
+            const index_t InLeftPadH = input_left_pads[1];
+            const index_t InLeftPadW = input_left_pads[2];
+
+            const index_t InRightPadD = input_right_pads[0];
+            const index_t InRightPadH = input_right_pads[1];
+            const index_t InRightPadW = input_right_pads[2];
+
+            const auto in_n_di_hi_wi_c_grid_desc =
+                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
+
+            const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
+                in_n_di_hi_wi_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_pad_transform(Di, InLeftPadD, InRightPadD),
+                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
+                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                           make_pass_through_transform(C)),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+            const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
+                in_n_hip_wip_c_grid_desc,
+                make_tuple(
+                    make_pass_through_transform(N),
+                    make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
+                    make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
+                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
+                    make_pass_through_transform(C)),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                make_tuple(Sequence<0>{},
+                           Sequence<1, 2>{},
+                           Sequence<3, 4>{},
+                           Sequence<5, 6>{},
+                           Sequence<7>{}));
+
+            const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
+                in_n_z_do_y_ho_x_wo_c_grid_desc,
+                make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
+                           make_merge_transform(make_tuple(N, Do, Ho, Wo))),
+                make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
+                in_gemmk_gemmmraw_grid_desc,
+                make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
+                           make_pass_through_transform(gemm_m)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            // in_gemmk0_gemmm_gemmk1_grid_desc
+            return transform_tensor_descriptor(
+                in_gemmk0_gemmmraw_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(gemm_k0),
+                           make_right_pad_transform(gemm_m, gemm_m_pad),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+        }
+    }
     static index_t GetGemmMRaw(ck::index_t N,
                                const std::vector<ck::index_t>& output_spatial_lengths)
     {
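The added NDim == 3 path is the usual implicit-GEMM view of an NDHWC forward convolution: output pixels become GEMM rows, output channels become columns, and the reduction runs over the filter window times the input channels, with the K dimension further split as gemm_k0 x GemmK1Number. A minimal standalone sketch of that sizing arithmetic follows; the concrete problem sizes and the GemmK1Number value are assumptions for illustration, not values from this commit.

    #include <cstdint>
    #include <iostream>

    int main()
    {
        // Hypothetical 3-D NDHWC convolution problem (sizes chosen for illustration).
        const std::int64_t N = 4, C = 64, K = 128;    // batch, input channels, output channels
        const std::int64_t Do = 8, Ho = 14, Wo = 14;  // output spatial lengths
        const std::int64_t Z = 3, Y = 3, X = 3;       // filter spatial lengths

        // Implicit-GEMM view built by the NDim == 3 descriptor path: (N, Do, Ho, Wo)
        // merge into gemm_m and (Z, Y, X, C) merge into gemm_k.
        const std::int64_t gemm_m = N * Do * Ho * Wo; // one GEMM row per output pixel
        const std::int64_t gemm_n = K;                // one GEMM column per output channel
        const std::int64_t gemm_k = Z * Y * X * C;    // reduction over filter window x channels

        // gemm_k is split into gemm_k0 * GemmK1Number (GemmK1Number = 8 is an assumption).
        const std::int64_t GemmK1Number = 8;
        const std::int64_t gemm_k0 = gemm_k / GemmK1Number;

        // Right-padding so gemm_m becomes a multiple of the tile size MPerBlock.
        const std::int64_t MPerBlock  = 256;
        const std::int64_t gemm_m_pad = (MPerBlock - gemm_m % MPerBlock) % MPerBlock;

        std::cout << "gemm_m=" << gemm_m << " gemm_n=" << gemm_n << " gemm_k=" << gemm_k
                  << " gemm_k0=" << gemm_k0 << " gemm_m_pad=" << gemm_m_pad << "\n";
    }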
@@ -445,6 +594,13 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             1,
             1,
             1,
             {1, 1},
             {1, 1},
             {1, 1},
             {1, 1},
             {1, 1},
             {1, 1},
             {1, 1});
     }

+    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
+    static auto GetABCGridDesc()
+    {
+        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+            1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
+    }
+
     using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

     using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
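The per-rank builders coexist in one struct because each overload is gated with std::enable_if on the NDim template parameter, and GetABCGridDesc<NumDimSpatial>() picks the one whose constraint holds. A minimal standalone sketch of that dispatch idiom, with illustrative names:

    #include <iostream>
    #include <type_traits>

    // enable_if-gated overloads keyed on a dimension count, mirroring how
    // GetABCGridDesc<NumDimSpatial>() selects the 2-D or 3-D descriptor builder.
    template <int NDim, typename std::enable_if<NDim == 2, bool>::type = false>
    int get_desc() { return 2; } // stand-in for the 2-D builder

    template <int NDim, typename std::enable_if<NDim == 3, bool>::type = false>
    int get_desc() { return 3; } // stand-in for the 3-D builder

    int main()
    {
        std::cout << get_desc<2>() << " " << get_desc<3>() << "\n"; // prints "2 3"
    }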
@@ -457,7 +613,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
         ABDataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
@@ -593,6 +749,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
     float Run(const Argument& arg,
               int nrepeat           = 1,
               hipStream_t stream_id = nullptr,
               bool measure_time     = false)
     {
 #if 0
         {
             std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                       << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
@@ -605,7 +762,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }
 #endif
         if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                         arg.b_grid_desc_k0_n_k1_,
                                         arg.c_grid_desc_m_n_,
@@ -708,8 +865,24 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -708,8 +865,24 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
// Input tensors can't be bigger than 2GB each.
constexpr
std
::
size_t
GB2
=
2
*
1e9
;
if
(
arg
.
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
()
>
GB2
)
{
return
false
;
}
if
(
arg
.
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
()
>
GB2
)
{
return
false
;
}
if
(
arg
.
c_grid_desc_m_n_
.
GetElementSpaceSize
()
>
GB2
)
{
return
false
;
}
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
for
(
ck
::
index_t
i
=
0
;
i
<
NumDimSpatial
;
++
i
)
for
(
ck
::
index_t
i
=
0
;
i
<
NumDimSpatial
;
++
i
)
...
@@ -722,7 +895,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -722,7 +895,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
// check if it's 1x1 conv
// check if it's 1x1 conv
for
(
ck
::
index_t
i
=
0
;
i
<
NumDimSpatial
;
++
i
)
for
(
ck
::
index_t
i
=
0
;
i
<
NumDimSpatial
;
++
i
)
...
@@ -855,7 +1028,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -855,7 +1028,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
getConvFwdSpecializationStr
(
ConvForwardSpecialization
)
<<
">"
;
<<
">"
;
// clang-format on
// clang-format on
...
...
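The new IsSupportedArgument guard rejects any grid descriptor whose element space exceeds 2e9, likely to keep per-tensor offsets within a safe range (the diff does not state the motivation). Note the comparison is in elements as returned by GetElementSpaceSize(), not bytes. A trivial standalone mirror of the check, with hypothetical sizes:

    #include <cstddef>
    #include <iostream>

    int main()
    {
        // Same constant expression the diff uses for the per-tensor limit.
        constexpr std::size_t GB2 = 2 * 1e9;

        const std::size_t a_elems = 3'000'000'000; // hypothetical descriptor element count
        const std::size_t b_elems = 1'000'000;

        std::cout << std::boolalpha
                  << (a_elems <= GB2) << "\n"  // false: this argument would be rejected
                  << (b_elems <= GB2) << "\n"; // true: within the limit
    }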
include/ck/tensor_operation/gpu/device/device_gemm.hpp

-#ifndef DEVICE_GEMM_HPP
-#define DEVICE_GEMM_HPP
+#pragma once

 #include <iostream>
+#include <vector>
 #include "device_base.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-struct DeviceGemmBias : public BaseOperator
-{
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b,
-                        const void* p_bias,
-                        void* p_c,
-                        ck::index_t M,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t StrideA,
-                        ck::index_t StrideB,
-                        ck::index_t StrideC,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op) = 0;
-
-    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
-};
-
-template <typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation>
-using DeviceGemmBiasPtr = std::unique_ptr<
-    DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+struct GemmShape
+{
+    ck::index_t M, N, K;
+    ck::index_t StrideA, StrideB, StrideC;
+};

 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>

@@ -65,7 +42,29 @@ template <typename AElementwiseOperation,
 using DeviceGemmPtr = std::unique_ptr<
     DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceGroupedGemm : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::vector<const void*>& p_a,
+                        std::vector<const void*>& p_b,
+                        std::vector<void*>& p_c,
+                        std::vector<GemmShape>& gemm_shapes,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
+                        ck::index_t KBatch = 1) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceGroupedGemmPtr = std::unique_ptr<
+    DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
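The new DeviceGroupedGemm interface takes per-problem pointer vectors plus a vector of GemmShape descriptors, so one launch can cover several independent GEMMs. A hedged host-side sketch of how a caller might assemble those vectors; the shapes, the row-major stride choice, and the local GemmShape stand-in are illustrative (a real call would go through a concrete op such as the one added in device_grouped_gemm_xdl.hpp):

    #include <cstdint>
    #include <vector>

    // Local stand-in for ck::tensor_operation::device::GemmShape from this commit.
    struct GemmShape
    {
        std::int32_t M, N, K;
        std::int32_t StrideA, StrideB, StrideC;
    };

    int main()
    {
        // Two independent row-major GEMM problems batched into one grouped launch:
        // for row-major A/B/C, StrideA = K, StrideB = N, StrideC = N.
        std::vector<GemmShape> gemm_shapes = {
            {256, 128, 64, 64, 128, 128},
            {512, 256, 32, 32, 256, 256},
        };

        // Device pointers in real use; null here since this is a shape-only sketch.
        std::vector<const void*> p_a(gemm_shapes.size(), nullptr);
        std::vector<const void*> p_b(gemm_shapes.size(), nullptr);
        std::vector<void*> p_c(gemm_shapes.size(), nullptr);

        // A concrete DeviceGroupedGemm implementation would consume these via
        // MakeArgumentPointer(p_a, p_b, p_c, gemm_shapes, a_op, b_op, c_op, /*KBatch=*/1).
        (void)p_a; (void)p_b; (void)p_c;
    }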
include/ck/tensor_operation/gpu/device/device_gemm_bias.hpp (new file, mode 100644)

#pragma once
#include <iostream>
#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmBias : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        const void* p_bias,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGemmBiasPtr = std::unique_ptr<
    DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp (new file, mode 100644)

#pragma once
#include <iostream>
#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename D0ReduceOperation,
          typename D1ReduceOperation>
struct DeviceGemmReduce : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        void* p_d0,
                        void* p_d1,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
                        D0ReduceOperation d0_reduce_op,
                        D1ReduceOperation d1_reduce_op,
                        ck::index_t BatchCount = 1) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename D0ReduceOperation,
          typename D1ReduceOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
                                                             BElementwiseOperation,
                                                             CElementwiseOperation,
                                                             D0ReduceOperation,
                                                             D1ReduceOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
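In the xdl_cshuffle implementation below, the D outputs are indexed by M only (DGridDesc_M), i.e. each GEMM row additionally produces two values reduced across the N dimension. A scalar reference of that fused shape contract, assuming purely for illustration that D0 reduces with a sum and D1 with a sum of squares (the commit leaves both reduce operations as template parameters):

    #include <vector>

    // Scalar reference for a GEMM whose epilogue also reduces each output row,
    // mirroring DeviceGemmReduce's shapes: C is M x N, D0 and D1 are length M.
    // The sum / sum-of-squares choices for d0/d1 are illustrative assumptions.
    void gemm_reduce_ref(const std::vector<float>& a, // M x K, row-major
                         const std::vector<float>& b, // K x N, row-major
                         std::vector<float>& c,       // M x N, row-major
                         std::vector<float>& d0,      // M
                         std::vector<float>& d1,      // M
                         int M, int N, int K)
    {
        for(int m = 0; m < M; ++m)
        {
            float acc0 = 0.f, acc1 = 0.f;
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += a[m * K + k] * b[k * N + n];
                c[m * N + n] = acc;
                acc0 += acc;       // D0ReduceOperation ~ sum (assumed)
                acc1 += acc * acc; // D1ReduceOperation ~ sum of squares (assumed)
            }
            d0[m] = acc0;
            d1[m] = acc1;
        }
    }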
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp (new file, mode 100644)

#pragma once
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_gemm_reduce.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
#include "gemm_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename ReduceAccDataType,
          typename DDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename D0ReduceOperation,
          typename D1ReduceOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
          index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
                                                               BElementwiseOperation,
                                                               CElementwiseOperation,
                                                               D0ReduceOperation,
                                                               D1ReduceOperation>
{
    using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto MPad = M - MRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_right_pad_transform(MRaw, MPad),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_right_pad_transform(MRaw, MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(MRaw),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }
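Each descriptor builder first rounds the raw extents up to tile multiples with math::integer_divide_ceil and derives the pad amounts from the difference; which pads are then applied depends on GemmSpec. A standalone sketch of just that arithmetic, with illustrative sizes:

    #include <iostream>

    // ceil(a / b) for positive integers, as math::integer_divide_ceil computes in CK.
    int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        const int MRaw = 1000, KRaw = 70; // hypothetical raw GEMM sizes
        const int MPerBlock = 256, KPerBlock = 32;

        const int M = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 1024
        const int K = integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; // 96

        const int MPad = M - MRaw; // 24 rows of right-padding
        const int KPad = K - KRaw; // 26 columns of right-padding

        // GemmSpec decides which pads are used: e.g. MKPadding applies both,
        // MPadding only MPad; the unpadded branches instead assert KRaw % AK1 == 0.
        std::cout << M << " " << K << " " << MPad << " " << KPad << "\n";
    }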
    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_right_pad_transform(NRaw, NPad),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_pass_through_transform(NRaw),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideC));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(MRaw, MPad),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(MRaw, MPad),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(MRaw),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }

    // assume D is packed tensor
    static auto MakeDGridDescriptor_M(index_t MRaw)
    {
        const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;

        const auto MPad = M - MRaw;

        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                     GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M
            return transform_tensor_descriptor(d_grid_desc_mraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
                                               make_tuple(Sequence<0>{}),
                                               make_tuple(Sequence<0>{}));
        }
        else
        {
            // not pad M
            return d_grid_desc_mraw;
        }
    }
    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N       = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
    using DGridDesc_M         = decltype(MakeDGridDescriptor_M(1));
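The descriptor types are named by calling the builders with dummy arguments inside decltype: the call is unevaluated and exists only to pin down the specialization-dependent return type. A minimal standalone sketch of the idiom, with illustrative names:

    #include <tuple>
    #include <type_traits>
    #include <utility>

    // A builder whose return type depends on compile-time configuration,
    // loosely mirroring MakeAGridDescriptor_AK0_M_AK1 above.
    template <bool Pad>
    static auto make_desc(int m, int k)
    {
        if constexpr(Pad)
            return std::tuple<int, int, int>{m, k, 0}; // padded form carries extra state
        else
            return std::pair<int, int>{m, k};
    }

    // decltype on a dummy call names the concrete type without computing anything real.
    using PaddedDesc   = decltype(make_desc<true>(1, 1));
    using UnpaddedDesc = decltype(make_desc<false>(1, 1));

    static_assert(!std::is_same_v<PaddedDesc, UnpaddedDesc>, "types differ by config");

    int main() {}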
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        ReduceAccDataType,
        DDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        D0ReduceOperation,
        D1ReduceOperation,
        InMemoryDataOperationEnum::Set,
        InMemoryDataOperationEnum::AtomicAdd,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        CGridDesc_M_N,
        DGridDesc_M,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        CReduceThreadClusterLengths_MPerBlock_NPerBlock,
        CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
        CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 CDataType* p_c_grid,
                 DDataType* p_d0_grid,
                 DDataType* p_d1_grid,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op,
                 D0ReduceOperation d0_reduce_op,
                 D1ReduceOperation d1_reduce_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
              p_d0_grid_{p_d0_grid},
              p_d1_grid_{p_d1_grid},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
              d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              d_grid_desc_mblock_mperblock_{},
              block_2_ctile_map_{},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op},
              d0_reduce_op_{d0_reduce_op},
              d1_reduce_op_{d1_reduce_op}
        {
            if(GridwiseGemm::CheckValidity(
                   a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);

                d_grid_desc_mblock_mperblock_ =
                    GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);

                block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
            }
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        DDataType* p_d0_grid_;
        DDataType* p_d1_grid_;
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        DGridDesc_M d_grid_desc_m_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
        D0ReduceOperation d0_reduce_op_;
        D1ReduceOperation d1_reduce_op_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, int /* nrepeat */ = 1)
        {
#if 0
            {
                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;

                std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
                          << std::endl;
            }
#endif
            if(!GridwiseGemm::CheckValidity(
                   arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

            const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);

            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

            if(has_main_k0_block_loop)
            {
                const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    CDataType,
                    DDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    D0ReduceOperation,
                    D1ReduceOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    true>;

                launch_kernel(kernel,
                              dim3(grid_size),
                              dim3(BlockSize),
                              0,
                              arg.p_a_grid_,
                              arg.p_b_grid_,
                              arg.p_c_grid_,
                              arg.p_d0_grid_,
                              arg.p_d1_grid_,
                              arg.a_element_op_,
                              arg.b_element_op_,
                              arg.c_element_op_,
                              arg.d0_reduce_op_,
                              arg.d1_reduce_op_,
                              arg.a_grid_desc_ak0_m_ak1_,
                              arg.b_grid_desc_bk0_n_bk1_,
                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                              arg.d_grid_desc_mblock_mperblock_,
                              arg.block_2_ctile_map_);
            }
            else
            {
                const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    CDataType,
                    DDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    D0ReduceOperation,
                    D1ReduceOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    false>;

                launch_kernel(kernel,
                              dim3(grid_size),
                              dim3(BlockSize),
                              0,
                              arg.p_a_grid_,
                              arg.p_b_grid_,
                              arg.p_c_grid_,
                              arg.p_d0_grid_,
                              arg.p_d1_grid_,
                              arg.a_element_op_,
                              arg.b_element_op_,
                              arg.c_element_op_,
                              arg.d0_reduce_op_,
                              arg.d1_reduce_op_,
                              arg.a_grid_desc_ak0_m_ak1_,
                              arg.b_grid_desc_bk0_n_bk1_,
                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                              arg.d_grid_desc_mblock_mperblock_,
                              arg.block_2_ctile_map_);
            }

            return 0;
        }
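has_main_k0_block_loop is a runtime value, but the kernel wants it as a compile-time flag, so Run instantiates the kernel template twice (true and false) and branches once on the host. A minimal standalone sketch of that dispatch idiom; CalculateHasMainK0BlockLoop's exact criterion is not shown in this diff, so K0 > 1 below is only a placeholder:

    #include <iostream>

    // Kernel stand-in: the bool template parameter lets the compiler strip the
    // main-loop code path entirely from the 'false' instantiation.
    template <bool HasMainLoop>
    void run_kernel(int k0)
    {
        if constexpr(HasMainLoop)
            std::cout << "main K0 loop, K0 = " << k0 << "\n";
        else
            std::cout << "tail-only path, K0 = " << k0 << "\n";
    }

    int main()
    {
        const int K0 = 4;
        const bool has_main_loop = K0 > 1; // runtime decision (placeholder criterion)
        if(has_main_loop)
            run_kernel<true>(K0); // one instantiation per flag value
        else
            run_kernel<false>(K0);
    }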
        // polymorphic
        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        return GridwiseGemm::CheckValidity(
            arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             DDataType* p_d0,
                             DDataType* p_d1,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op,
                             D0ReduceOperation d0_reduce_op,
                             D1ReduceOperation d1_reduce_op)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        p_d0,
                        p_d1,
                        MRaw,
                        NRaw,
                        KRaw,
                        StrideA,
                        StrideB,
                        StrideC,
                        a_element_op,
                        b_element_op,
                        c_element_op,
                        d0_reduce_op,
                        d1_reduce_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      void* p_d0,
                                                      void* p_d1,
                                                      index_t MRaw,
                                                      index_t NRaw,
                                                      index_t KRaw,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op,
                                                      D0ReduceOperation d0_reduce_op,
                                                      D1ReduceOperation d1_reduce_op,
                                                      index_t /* KBatch */ = 1) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          static_cast<DDataType*>(p_d0),
                                          static_cast<DDataType*>(p_d1),
                                          MRaw,
                                          NRaw,
                                          KRaw,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op,
                                          d0_reduce_op,
                                          d1_reduce_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGemmReduce_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp

@@ -27,7 +27,7 @@ template <typename ADataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          GemmSpecialization_t GemmSpecialization,
+          GemmSpecialization GemmSpec,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
@@ -80,7 +80,7 @@ struct DeviceGemmXdl
            }
        }();

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
@@ -119,7 +119,7 @@ struct DeviceGemmXdl
            }
        }();

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -154,7 +154,7 @@ struct DeviceGemmXdl
            }
        }();

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -186,7 +186,7 @@ struct DeviceGemmXdl
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType,
-       InMemoryDataOperationEnum_t::Set,
+       InMemoryDataOperationEnum::Set,
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        CGridDesc_M_N,
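DeviceGemmXdl pads with (MPerBlock - M % MPerBlock) % MPerBlock, which is the same round-up-to-multiple quantity the ceil-div formulation in the other files computes; a quick standalone check of the equivalence:

    #include <cassert>

    int main()
    {
        const int MPerBlock = 256;
        for(int M = 1; M <= 1024; ++M)
        {
            const int pad_mod  = (MPerBlock - M % MPerBlock) % MPerBlock;        // DeviceGemmXdl form
            const int pad_ceil = (M + MPerBlock - 1) / MPerBlock * MPerBlock - M; // ceil-div form
            assert(pad_mod == pad_ceil); // the outer % makes the pad 0 when M is already a multiple
        }
    }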
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp

@@ -138,7 +138,7 @@ struct DeviceGemmXdl_C_Shuffle
         AccDataType,
         CShuffleDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp

@@ -4,9 +4,7 @@
 #include <iostream>
 #include <sstream>
 #include "device.hpp"
-#include "device_base.hpp"
-#include "device_gemm.hpp"
-#include "device_gemm_xdl.hpp"
+#include "device_gemm_bias.hpp"
 #include "common_header.hpp"
 #include "tensor_layout.hpp"
 #include "tensor_descriptor.hpp"
@@ -141,7 +139,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp

@@ -147,7 +147,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp

@@ -169,7 +169,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
0 → 100644
View file @
dd6a8de4
#pragma once
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_gemm.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
>
struct
DeviceGemm_Xdl_CShuffle
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemm_Xdl_CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_right_pad_transform(NRaw, NPad),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_n_k,
                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                       make_pass_through_transform(N)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                       make_right_pad_transform(NRaw, NPad)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_pass_through_transform(NRaw),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_n_k,
                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                       make_pass_through_transform(NRaw)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                       make_pass_through_transform(NRaw)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideC));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_pass_through_transform(NRaw)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_pass_through_transform(MRaw),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }
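The three descriptor builders apply the GemmSpecialization consistently: each padded dimension is padded in every tensor that carries it (M in A and C, N in B and C, K in A and B). A tiny standalone restatement of that branch logic, with a local enum standing in for ck's GemmSpecialization (the enumerator names are taken from this file):

    #include <cstdio>

    enum class Spec { Default, MPadding, NPadding, KPadding,
                      MNPadding, MKPadding, NKPadding, MNKPadding };

    bool padsM(Spec s) { return s == Spec::MPadding || s == Spec::MNPadding ||
                                s == Spec::MKPadding || s == Spec::MNKPadding; }
    bool padsN(Spec s) { return s == Spec::NPadding || s == Spec::MNPadding ||
                                s == Spec::NKPadding || s == Spec::MNKPadding; }
    bool padsK(Spec s) { return s == Spec::KPadding || s == Spec::MKPadding ||
                                s == Spec::NKPadding || s == Spec::MNKPadding; }

    int main()
    {
        // e.g. MKPadding pads M (in A and C) and K (in A and B), but not N
        std::printf("MKPadding: padsM=%d padsN=%d padsK=%d\n",
                    padsM(Spec::MKPadding), padsN(Spec::MKPadding), padsK(Spec::MKPadding));
        return 0;
    }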
    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N       = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        CGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock>;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 CDataType* p_c_grid,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
            if(GridwiseGemm::CheckValidity(
                   a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);

                block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
            }
        }

        //  private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, int nrepeat = 1)
        {
#if 0
            {
                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
            }
#endif

            if(!GridwiseGemm::CheckValidity(
                   arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);

            const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);

            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

            float ave_time = 0;

            if(has_main_k0_block_loop)
            {
                const auto kernel = kernel_gemm_xdl_cshuffle_v1<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    CDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    true>;

                if(nrepeat == 0)
                {
                    launch_kernel(kernel,
                                  dim3(grid_size),
                                  dim3(BlockSize),
                                  0,
                                  arg.p_a_grid_,
                                  arg.p_b_grid_,
                                  arg.p_c_grid_,
                                  arg.a_element_op_,
                                  arg.b_element_op_,
                                  arg.c_element_op_,
                                  arg.a_grid_desc_ak0_m_ak1_,
                                  arg.b_grid_desc_bk0_n_bk1_,
                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                  arg.block_2_ctile_map_);
                }
                else
                {
                    ave_time =
                        launch_and_time_kernel(kernel,
                                               nrepeat,
                                               dim3(grid_size),
                                               dim3(BlockSize),
                                               0,
                                               arg.p_a_grid_,
                                               arg.p_b_grid_,
                                               arg.p_c_grid_,
                                               arg.a_element_op_,
                                               arg.b_element_op_,
                                               arg.c_element_op_,
                                               arg.a_grid_desc_ak0_m_ak1_,
                                               arg.b_grid_desc_bk0_n_bk1_,
                                               arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                               arg.block_2_ctile_map_);
                }
            }
            else
            {
                const auto kernel = kernel_gemm_xdl_cshuffle_v1<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    CDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    false>;

                if(nrepeat == 0)
                {
                    launch_kernel(kernel,
                                  dim3(grid_size),
                                  dim3(BlockSize),
                                  0,
                                  arg.p_a_grid_,
                                  arg.p_b_grid_,
                                  arg.p_c_grid_,
                                  arg.a_element_op_,
                                  arg.b_element_op_,
                                  arg.c_element_op_,
                                  arg.a_grid_desc_ak0_m_ak1_,
                                  arg.b_grid_desc_bk0_n_bk1_,
                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                  arg.block_2_ctile_map_);
                }
                else
                {
                    ave_time =
                        launch_and_time_kernel(kernel,
                                               nrepeat,
                                               dim3(grid_size),
                                               dim3(BlockSize),
                                               0,
                                               arg.p_a_grid_,
                                               arg.p_b_grid_,
                                               arg.p_c_grid_,
                                               arg.a_element_op_,
                                               arg.b_element_op_,
                                               arg.c_element_op_,
                                               arg.a_grid_desc_ak0_m_ak1_,
                                               arg.b_grid_desc_bk0_n_bk1_,
                                               arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                               arg.block_2_ctile_map_);
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
        }
    };
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        return GridwiseGemm::CheckValidity(
            arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        MRaw,
                        NRaw,
                        KRaw,
                        StrideA,
                        StrideB,
                        StrideC,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t MRaw,
                                                      index_t NRaw,
                                                      index_t KRaw,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op,
                                                      index_t /* KBatch */ = 1) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          MRaw,
                                          NRaw,
                                          KRaw,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGemm_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
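One way to read Invoker::Run above: the grid size is derived from the C descriptor, and a plausible interpretation (an assumption, since CalculateGridSize is defined elsewhere) is one workgroup per C tile after padding. A trivial sketch of that count, with illustrative numbers:

    #include <cstdio>

    int main()
    {
        // padded output sizes and per-block tile sizes (illustrative)
        const int M = 1024, N = 2048, MPerBlock = 256, NPerBlock = 128;

        // assumed reading of CalculateGridSize: one workgroup per C tile
        const int grid_size = (M / MPerBlock) * (N / NPerBlock);
        std::printf("grid_size = %d\n", grid_size); // 64
        return 0;
    }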
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
View file @ dd6a8de4

...
@@ -31,7 +31,7 @@ template <typename ADataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          GemmSpecialization_t GemmSpecialization,
+          GemmSpecialization GemmSpec,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
...
@@ -91,7 +91,7 @@ struct DeviceGemmXdlSplitK
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
...
@@ -136,7 +136,7 @@ struct DeviceGemmXdlSplitK
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
...
@@ -170,7 +170,7 @@ struct DeviceGemmXdlSplitK
             }
         }();

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
...
@@ -209,7 +209,7 @@ struct DeviceGemmXdlSplitK
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
...
@@ -250,7 +250,7 @@ struct DeviceGemmXdlSplitK
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::AtomicAdd,
+        InMemoryDataOperationEnum::AtomicAdd,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
...
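The two hunks at the bottom show why the split-K device op instantiates two gridwise gemms: one writing with Set, one with AtomicAdd. When K is split into KBatch slices, each slice produces a partial C that must be summed into the same output, so plain overwrites would race. A CPU analogue of that accumulation (a sketch, not library code):

    #include <cstdio>
    #include <vector>

    int main()
    {
        const int K = 8, KBatch = 2, KPerBatch = K / KBatch;
        std::vector<float> a(K, 1.0f), b(K, 2.0f);

        float c = 0.0f; // one shared output element
        for(int kb = 0; kb < KBatch; ++kb)
        {
            float partial = 0.0f; // each "workgroup" owns one K-slice
            for(int k = kb * KPerBatch; k < (kb + 1) * KPerBatch; ++k)
                partial += a[k] * b[k];
            c += partial; // on the GPU this combine is the atomic add
        }
        std::printf("c = %f\n", c); // 16.0
        return 0;
    }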
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
View file @ dd6a8de4

...
@@ -31,7 +31,7 @@ template <typename ADataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          GemmSpecialization_t GemmSpecialization,
+          GemmSpecialization GemmSpec,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
...
@@ -93,7 +93,7 @@ struct DeviceGemmXdlSplitKCShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
...
@@ -138,7 +138,7 @@ struct DeviceGemmXdlSplitKCShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
...
@@ -172,7 +172,7 @@ struct DeviceGemmXdlSplitKCShuffle
             }
         }();

-        if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
...
@@ -211,7 +211,7 @@ struct DeviceGemmXdlSplitKCShuffle
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
...
@@ -253,7 +253,7 @@ struct DeviceGemmXdlSplitKCShuffle
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum_t::AtomicAdd,
+        InMemoryDataOperationEnum::AtomicAdd,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
0 → 100644
View file @ dd6a8de4

#ifndef DEVICE_GROUPED_GEMM_XDL_HPP
#define DEVICE_GROUPED_GEMM_XDL_HPP

#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_gemm.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "gemm_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          ck::index_t NumPrefetch   = 1,
          ck::index_t MaxGroupCount = 10>
struct DeviceGroupedGemmXdl
    : public DeviceGroupedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto K1Number = Number<K1>{};

    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(M, PadM)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(c_grid_desc_m_n,
                                               make_tuple(make_right_pad_transform(M, PadM),
                                                          make_right_pad_transform(N, PadN)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(c_grid_desc_m_n,
                                               make_tuple(make_pass_through_transform(M),
                                                          make_pass_through_transform(N)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }

    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType,
        InMemoryDataOperationEnum::Set,
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        CGridDesc_M_N,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        K1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsAddExtraN,
        Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        NumPrefetch>;

    struct GroupedGemmBlock2CTileMap
    {
        GroupedGemmBlock2CTileMap()
        {
            block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1);
            BlockStart_        = -1;
        }

        GroupedGemmBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
                                  index_t M01,
                                  index_t N01,
                                  ck::index_t BlockStart)
        {
            block_2_ctile_map_ =
                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);
            BlockStart_ = BlockStart;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            return block_2_ctile_map_.CalculateBottomIndex(
                make_multi_index(idx_top[I0] - BlockStart_));
        }

        private:
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        ck::index_t BlockStart_;
    };
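GroupedGemmBlock2CTileMap rebases the global block id by subtracting BlockStart_ before the per-group tile map is consulted, so every group can reuse the ordinary single-gemm mapping. A standalone analogue of that range bookkeeping, mirroring how the Argument constructor below accumulates BlockStart/BlockEnd (the per-group grid sizes here are made up for illustration):

    #include <cstdio>

    int main()
    {
        const int grid_size_grp[3] = {12, 8, 20}; // blocks needed by each group
        int block_start[3], block_end[3], grid_size = 0;
        for(int g = 0; g < 3; ++g)
        {
            block_start[g] = grid_size;                   // first block owned by group g
            block_end[g]   = grid_size + grid_size_grp[g]; // one past the last
            grid_size += grid_size_grp[g];
        }

        const int global_block_id = 15; // falls in group 1's range [12, 20)
        for(int g = 0; g < 3; ++g)
            if(global_block_id >= block_start[g] && global_block_id < block_end[g])
                std::printf("group %d, local block id %d\n",
                            g, global_block_id - block_start[g]); // group 1, local 3
        return 0;
    }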
    struct GemmDescKernelArg
    {
        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
        GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_;

        const ADataType* a_ptr;
        const BDataType* b_ptr;
        CDataType* c_ptr;

        ck::index_t BlockStart_, BlockEnd_;
    };

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::vector<const void*>& p_a,
                 std::vector<const void*>& p_b,
                 std::vector<void*>& p_c,
                 std::vector<GemmShape>& gemm_shapes,
                 index_t M01,
                 index_t N01,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
            grid_size_ = 0;

            group_count_ = static_cast<int>(gemm_shapes.size());

            if(!(group_count_ == p_a.size() && group_count_ == p_b.size() &&
                 group_count_ == p_c.size()))
            {
                throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
            }

            gemm_desc_kernel_arg_.reserve(group_count_);

            for(index_t i = 0; i < gemm_shapes.size(); i++)
            {
                const index_t M = gemm_shapes[i].M;
                const index_t N = gemm_shapes[i].N;
                const index_t K = gemm_shapes[i].K;

                const index_t StrideA = gemm_shapes[i].StrideA;
                const index_t StrideB = gemm_shapes[i].StrideB;
                const index_t StrideC = gemm_shapes[i].StrideC;

                const auto a_grid_desc_k0_m_k1_ =
                    DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
                const auto b_grid_desc_k0_n_k1_ =
                    DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
                const auto c_grid_desc_m_n_ =
                    DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);

                const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);

                const index_t BlockStart = grid_size_;
                const index_t BlockEnd   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                if(GridwiseGemm::CheckValidity(
                       a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
                {
                    const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                        GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
                            c_grid_desc_m_n_);

                    const auto grouped_gemm_block_2_ctile_map_ =
                        GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);

                    gemm_desc_kernel_arg_.push_back(
                        GemmDescKernelArg{a_grid_desc_k0_m_k1_,
                                          b_grid_desc_k0_n_k1_,
                                          c_grid_desc_m_n_,
                                          c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                                          grouped_gemm_block_2_ctile_map_,
                                          static_cast<const ADataType*>(p_a[i]),
                                          static_cast<const BDataType*>(p_b[i]),
                                          static_cast<CDataType*>(p_c[i]),
                                          BlockStart,
                                          BlockEnd});
                }
            }
        }

        //  private:
        index_t M01_;
        index_t N01_;
        index_t group_count_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;

        std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;

        index_t grid_size_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceGroupedGemmXdl::Argument;

        float Run(const Argument& arg, int nrepeat = 1)
        {
            StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg;

            bool has_main_k0_block_loop = true;

            static_for<0, MaxGroupCount, 1>{}([&](auto i) {
                if(i < arg.gemm_desc_kernel_arg_.size())
                {
                    gemm_desc_kernel_arg_arg(i) = arg.gemm_desc_kernel_arg_[i];

                    std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
                              << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0)
                              << ", "
                              << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1)
                              << ", "
                              << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2)
                              << "}";

                    std::cout << ", arg.b_grid_desc_k0_n_k1_{"
                              << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0)
                              << ", "
                              << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1)
                              << ", "
                              << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2)
                              << "}";

                    std::cout << ", arg.c_grid_desc_m_n_{ "
                              << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", "
                              << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}"
                              << std::endl;

                    if(!GridwiseGemm::CheckValidity(
                           gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_,
                           gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_,
                           gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_,
                           arg.M01_,
                           arg.N01_))
                    {
                        throw std::runtime_error(
                            "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
                    }

                    const auto K0 = gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0);

                    if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
                    {
                        throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop");
                    }
                }
            });

            float ave_time = 0;

            if(has_main_k0_block_loop)
            {
                const auto kernel =
                    kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
                                                    ADataType, // TODO: distinguish A/B datatype
                                                    CDataType,
                                                    remove_reference_t<GemmDescKernelArg>,
                                                    AElementwiseOperation,
                                                    BElementwiseOperation,
                                                    CElementwiseOperation,
                                                    true,
                                                    MaxGroupCount>;

                ave_time = launch_and_time_kernel(kernel,
                                                  nrepeat,
                                                  dim3(arg.grid_size_),
                                                  dim3(BlockSize),
                                                  0,
                                                  gemm_desc_kernel_arg_arg,
                                                  arg.gemm_desc_kernel_arg_.size(),
                                                  arg.a_element_op_,
                                                  arg.b_element_op_,
                                                  arg.c_element_op_);
            }
            else
            {
                const auto kernel =
                    kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
                                                    ADataType, // TODO: distinguish A/B datatype
                                                    CDataType,
                                                    remove_reference_t<GemmDescKernelArg>,
                                                    AElementwiseOperation,
                                                    BElementwiseOperation,
                                                    CElementwiseOperation,
                                                    false,
                                                    MaxGroupCount>;

                ave_time = launch_and_time_kernel(kernel,
                                                  nrepeat,
                                                  dim3(arg.grid_size_),
                                                  dim3(BlockSize),
                                                  0,
                                                  gemm_desc_kernel_arg_arg,
                                                  arg.gemm_desc_kernel_arg_.size(),
                                                  arg.a_element_op_,
                                                  arg.b_element_op_,
                                                  arg.c_element_op_);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_)
            return false;
        else
            return true;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(std::vector<const void*>& p_a,
                             std::vector<const void*>& p_b,
                             std::vector<void*>& p_c,
                             std::vector<GemmShape> gemm_shapes,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
                                                      std::vector<const void*>& p_b,
                                                      std::vector<void*>& p_c,
                                                      std::vector<GemmShape>& gemm_shapes,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op,
                                                      index_t /* KBatch */ = 1) override
    {
        return std::make_unique<Argument>(
            p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGroupedGemmXdl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
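The Argument constructor above consumes one GemmShape per group. GemmShape's fields (M, N, K, StrideA, StrideB, StrideC) are taken from how they are read in this file; the struct itself is defined elsewhere, so the stand-in below is a local sketch of how a caller might assemble the shape list before calling MakeArgument:

    #include <vector>

    struct GemmShape // local stand-in for ck's GemmShape
    {
        int M, N, K;
        int StrideA, StrideB, StrideC;
    };

    int main()
    {
        std::vector<GemmShape> gemm_shapes;
        // two groups with different problem sizes (strides chosen for
        // row-major A/C and a K-contiguous B as an example)
        gemm_shapes.push_back({256, 128, 64, /*StrideA=*/64, /*StrideB=*/64, /*StrideC=*/128});
        gemm_shapes.push_back({512, 256, 128, 128, 128, 256});
        // gemm_shapes, plus matching p_a/p_b/p_c pointer vectors, would then
        // be passed to DeviceGroupedGemmXdl::MakeArgument(...)
        return 0;
    }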
include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp
View file @ dd6a8de4

...
@@ -10,7 +10,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <ck::ReduceTensorOp_t ReduceOpId>
+template <ck::ReduceTensorOp ReduceOpId>
 struct DevicePool2dFwd : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
...
@@ -29,7 +29,7 @@ struct DevicePool2dFwd : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <ck::ReduceTensorOp_t ReduceOpId>
+template <ck::ReduceTensorOp ReduceOpId>
 using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;

 } // namespace device
...
include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp
View file @ dd6a8de4

...
@@ -16,7 +16,7 @@ namespace device {
 template <typename InDataType,
           typename OutDataType,
           typename AccDataType,
-          ck::ReduceTensorOp_t ReduceOpId,
+          ck::ReduceTensorOp ReduceOpId,
           bool NeedIndices,
           ck::index_t BlockSize,
           ck::index_t ReduceMThreadClusterSize,
...
@@ -181,7 +181,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
         reduce_lowest_length_ = window_spatial_lengths[1];

         // TODO: is this correct?
-        if constexpr(ReduceOpId == ck::ReduceTensorOp_t::AVG)
+        if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
         {
             ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];

             in_element_op_ = InElementwiseOperation{divider};
...
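In the AVG-pooling branch above, the input elementwise op is constructed with the pooling-window size as divider, so averaging happens as part of the reduction rather than as a separate pass. A sketch of the arithmetic only (values illustrative, not library code):

    #include <cstdio>

    int main()
    {
        const int window_spatial_lengths[2] = {3, 3};
        const int divider = window_spatial_lengths[0] * window_spatial_lengths[1]; // 9

        const float window_sum = 18.0f;            // sum over one 3x3 window
        std::printf("avg = %f\n", window_sum / divider); // 2.0
        return 0;
    }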
include/ck/tensor_operation/gpu/device/device_reduce.hpp
View file @ dd6a8de4

...
@@ -16,9 +16,11 @@ namespace device {
 template <typename InElementwiseOperation, typename AccElementwiseOperation>
 struct DeviceReduce : public BaseOperator
 {
-    virtual size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths)
+    virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
+                                                 const std::vector<int> reduceDims)
     {
         (void)inLengths;
+        (void)reduceDims;

         return (0);
     };
...
@@ -32,18 +34,19 @@ struct DeviceReduce : public BaseOperator
     };

     virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& inElementwiseOp,
-                        const AccElementwiseOperation& accElementwiseOp) = 0;
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
...
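The interface change above moves reduceDims from a compile-time template parameter to a runtime argument; downstream (see device_reduce_blockwise.hpp next), the rank-N input is then viewed as a 2-d (invariant, reduce) problem. A standalone sketch of that length split, assuming get_2d_lengths multiplies invariant and reduced extents separately (an inference from how the results are used, not this file's definition):

    #include <cstdio>

    int main()
    {
        const int inLengths[4]    = {8, 16, 4, 32};              // rank-4 input
        const bool isReduceDim[4] = {false, true, false, true};  // reduceDims = {1, 3}

        long invariant_total_length = 1, reduce_total_length = 1;
        for(int d = 0; d < 4; ++d)
            (isReduceDim[d] ? reduce_total_length : invariant_total_length) *= inLengths[d];

        // the reduction then runs on an (M, K) = (32, 512) view
        std::printf("M = %ld, K = %ld\n", invariant_total_length, reduce_total_length);
        return 0;
    }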
include/ck/tensor_operation/gpu/device/device_reduce_blockwise.hpp
View file @ dd6a8de4

...
@@ -15,8 +15,8 @@ namespace device {
 template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
-          int Rank,
-          typename ReduceDims,
+          index_t Rank,
+          index_t NumReduceDim,
           typename ReduceOperation,
           typename InElementwiseOperation,
           typename AccElementwiseOperation,
...
@@ -36,15 +36,20 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");

+    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
     using IndexDataType = int32_t;

     static constexpr bool BetaIsZero = NeedIndices;

-    using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

     static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...
@@ -52,18 +57,18 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
     static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
                                     const std::vector<int>& inStrides)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

         const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
             {
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));

                 return transform_tensor_descriptor(one_dim_inDesc,
...
@@ -74,7 +79,10 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
             }
             else
             {
-                const auto toReduceDimLengths =
+                using InvariantDims =
+                    typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
+                const auto reduceDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                 const auto invariantDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
...
@@ -82,24 +90,26 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
                 return transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(invariantDimLengths),
-                               make_merge_transform(toReduceDimLengths)),
+                               make_merge_transform(reduceDimLengths)),
                     make_tuple(InvariantDims{}, ReduceDims{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }();

-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K =
+            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

         auto in_grid_desc_m_k_padded =
             transform_tensor_descriptor(
                 in_grid_desc_m_k,
-                make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                           make_right_pad_transform(innerLen, inPad_K)),
+                make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                           make_right_pad_transform(reduceLength, inPad_K)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));

         return (in_grid_desc_m_k_padded);
     };
...
@@ -107,67 +117,70 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
     static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                     const std::vector<int>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));

-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

-        const auto inPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto inPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

         auto out_grid_desc_m_padded =
             transform_tensor_descriptor(
                 out_grid_desc_m,
-                make_tuple(make_right_pad_transform(outerLen, inPad)),
+                make_tuple(make_right_pad_transform(invariantLength, inPad)),
                 make_tuple(Sequence<0>{}),
                 make_tuple(Sequence<0>{}));
         return (out_grid_desc_m_padded);
     };

     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int>& inLengths,
-                 const std::vector<int>& inStrides,
-                 const std::vector<int>& outLengths,
-                 const std::vector<int>& outStrides,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
                  OutDataType* out_dev,
                  IndexDataType* out_indices_dev,
                  AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const AccElementwiseOperation& acc_elementwise_op)
-            : in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
+                 const InElementwiseOperation in_elementwise_op,
+                 const AccElementwiseOperation acc_elementwise_op)
+            : outLengths_{outLengths},
+              outStrides_{outStrides},
+              in_dev_{in_dev},
+              out_dev_{out_dev},
+              out_indices_dev_{out_indices_dev},
+              in_elementwise_op_{in_elementwise_op},
+              acc_elementwise_op_{acc_elementwise_op}
         {
             (void)workspace_dev;

-            inLengths_  = inLengths;
-            inStrides_  = inStrides;
-            outLengths_ = outLengths;
-            outStrides_ = outStrides;
-
-            in_elementwise_op_  = in_elementwise_op;
-            acc_elementwise_op_ = acc_elementwise_op;
-
-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_  = static_cast<OutDataType>(beta);
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
+
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_  = type_convert<AccDataType>(beta);

             std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                 invariant_lowest_length = 1;
             else
-                invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];

-            reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];

             gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                        M_BlockTileSize;
...
@@ -179,7 +192,7 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
         std::vector<int> outStrides_;

         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;

         const InDataType* in_dev_;
         OutDataType* out_dev_;
...
@@ -273,18 +286,22 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
         if constexpr(InSrcVectorDim == 0)
         {
-            if constexpr(InvariantDims::Size() == 0)
-                return (false);
-
-            if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                return (false);
-
-            if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                return (false);
+            if constexpr(NumInvariantDim == 0)
+            {
+                return (false);
+            }
+            else
+            {
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);
+
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
+            };
         }
         else
         {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                 return (false);

             if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
...
@@ -303,23 +320,25 @@ struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccEl
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           outLengths,
                                           outStrides,
+                                          reduceDims,
                                           alpha,
                                           beta,
                                           static_cast<const InDataType*>(in_dev),
...
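The new Argument constructor permutes inLengths/inStrides with shuffle_tensor_dimensions before building descriptors. Judging by how inLengths_ is indexed afterwards (invariant dims in positions [0, NumInvariantDim), reduce dims at the back, with inLengths_[Rank - 1] the lowest reduce length), a plausible behavior is "move the reduce dims to the tail, keeping each group's relative order". The sketch below implements that assumed behavior; the real helper lives elsewhere in the library and may differ:

    #include <cstdio>
    #include <vector>

    // assumed semantics of ck's shuffle_tensor_dimensions<Rank, NumReduceDim>
    std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& lengths,
                                               const std::vector<int>& reduceDims)
    {
        const int rank = static_cast<int>(lengths.size());
        std::vector<bool> isReduce(rank, false);
        for(int d : reduceDims)
            isReduce[d] = true;

        std::vector<int> shuffled;
        for(int d = 0; d < rank; ++d) // invariant dims first, order preserved
            if(!isReduce[d])
                shuffled.push_back(lengths[d]);
        for(int d : reduceDims)       // reduce dims moved to the back
            shuffled.push_back(lengths[d]);
        return shuffled;
    }

    int main()
    {
        const std::vector<int> lengths = {8, 16, 4, 32};
        for(int v : shuffle_tensor_dimensions(lengths, {1, 3}))
            std::printf("%d ", v); // 8 4 16 32
        std::printf("\n");
        return 0;
    }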
include/ck/tensor_operation/gpu/device/device_reduce_blockwise_second_call.hpp
View file @
dd6a8de4
...
@@ -15,8 +15,8 @@ namespace device {
...
@@ -15,8 +15,8 @@ namespace device {
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
OutDataType
,
typename
OutDataType
,
int
Rank
,
in
dex_
t
Rank
,
typename
ReduceDim
s
,
index_t
Num
ReduceDim
,
typename
ReduceOperation
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
AccElementwiseOperation
,
...
@@ -37,6 +37,10 @@ struct DeviceReduceBlockWiseSecondCall
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");

+    static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
+                      (MThreadSliceSize % OutDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
     using IndexDataType = int32_t;

     static constexpr bool BetaIsZero = NeedIndices;
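
The new static_assert rejects tile configurations whose per-thread slices cannot be covered by whole vector loads and stores. A compile-time check with hypothetical numbers shows what it accepts:

// Hypothetical tile parameters, chosen only to illustrate the constraint:
constexpr int BlockSize          = 256;
constexpr int MThreadClusterSize = 128;
constexpr int KThreadClusterSize = 2;
constexpr int MThreadSliceSize   = 4;
constexpr int KThreadSliceSize   = 8;
constexpr int InSrcVectorDim     = 1; // vectorize along the reduce (K) dimension
constexpr int InSrcVectorSize    = 4; // 8 % 4 == 0: K slices split into whole vectors
constexpr int OutDstVectorSize   = 2; // 4 % 2 == 0: M slices split into whole vectors

static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
              "Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
                  (MThreadSliceSize % OutDstVectorSize == 0),
              "Invalid thread slice sizes and/or vector sizes configuration, please check!");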
...
@@ -45,9 +49,9 @@ struct DeviceReduceBlockWiseSecondCall
         std::is_same<InDataType, AccDataType>::value,
         "InDataType and AccDataType should be the same to use DeviceReduceBlockWiseSecondCall!");

-    using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;

     static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...
@@ -61,18 +65,20 @@ struct DeviceReduceBlockWiseSecondCall
         const auto in_grid_desc_m_k =
             make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K =
+            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

         auto in_grid_desc_m_k_padded =
             transform_tensor_descriptor(
                 in_grid_desc_m_k,
-                make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                           make_right_pad_transform(innerLen, inPad_K)),
+                make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                           make_right_pad_transform(reduceLength, inPad_K)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));

         return (in_grid_desc_m_k_padded);
     };
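
Both pads come from the same round-up-to-tile arithmetic: math::integer_least_multiple(x, tile) is x rounded up to the next multiple of tile, and the pad is the difference. A standalone restatement with concrete numbers; the helper below mirrors the CK utility but is written out here for illustration:

#include <cassert>

// Mirrors math::integer_least_multiple for illustration: round x up to the
// next multiple of tile.
constexpr int integer_least_multiple(int x, int tile)
{
    return ((x + tile - 1) / tile) * tile;
}

int main()
{
    constexpr int M_BlockTileSize = 128;  // hypothetical tile size
    constexpr int invariantLength = 1000; // hypothetical tensor length

    constexpr int inPad_M =
        integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

    static_assert(inPad_M == 24, "1000 rounds up to 1024, so 24 padding elements");
    assert((invariantLength + inPad_M) % M_BlockTileSize == 0);
    return 0;
}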
...
@@ -80,26 +86,27 @@ struct DeviceReduceBlockWiseSecondCall
     static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                     const std::vector<int>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));

-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

-        const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto outPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

         auto out_grid_desc_m_padded =
             transform_tensor_descriptor(
                 out_grid_desc_m,
-                make_tuple(make_right_pad_transform(outerLen, outPad)),
+                make_tuple(make_right_pad_transform(invariantLength, outPad)),
                 make_tuple(Sequence<0>{}),
                 make_tuple(Sequence<0>{}));

         return (out_grid_desc_m_padded);
     };
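
The merge transform collapses the kept output dimensions into a single M dimension whose length is their product; only then is the right-pad applied, exactly as in MakeSrc2dDescriptor. A tiny worked example with hypothetical lengths:

#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical kept (invariant) output lengths after reducing a rank-4 input.
    std::vector<int> outLengths{64, 960};

    long m = 1;
    for(int len : outLengths)
        m *= len; // merge transform: kept dims collapse into one M dimension

    std::printf("merged M = %ld\n", m); // 61440, then right-padded to the tile
    return 0;
}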
...
@@ -117,18 +124,18 @@ struct DeviceReduceBlockWiseSecondCall
                  AccDataType* workspace_dev,
                  const InElementwiseOperation& in_elementwise_op,
                  const AccElementwiseOperation& acc_elementwise_op)
-            : in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
+            : inLengths_(inLengths),
+              inStrides_(inStrides),
+              outLengths_(outLengths),
+              outStrides_(outStrides),
+              in_dev_{in_dev},
+              out_dev_{out_dev},
+              out_indices_dev_{out_indices_dev},
+              in_elementwise_op_(in_elementwise_op),
+              acc_elementwise_op_(acc_elementwise_op)
         {
-            inLengths_  = inLengths;
-            inStrides_  = inStrides;
-            outLengths_ = outLengths;
-            outStrides_ = outStrides;
-            in_elementwise_op_  = in_elementwise_op;
-            acc_elementwise_op_ = acc_elementwise_op;
-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_  = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_  = type_convert<AccDataType>(beta);

             invariant_total_length = inLengths[0];
             reduce_total_length    = inLengths[1];
...
@@ -155,7 +162,7 @@ struct DeviceReduceBlockWiseSecondCall
         std::vector<int> outStrides_;

         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;

         const InDataType* in_dev_;
         OutDataType* out_dev_;
...
@@ -266,19 +273,22 @@ struct DeviceReduceBlockWiseSecondCall
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
+        (void)reduceDims;
+
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           outLengths,
...
include/ck/tensor_operation/gpu/device/device_reduce_common.hpp
View file @ dd6a8de4
...
@@ -2,6 +2,7 @@
 #define DEVICE_REDUCE_COMMON_HPP

 #include <vector>
+#include <cassert>

 #include "common_header.hpp"
 #include "reduction_enums.hpp"
...
@@ -11,55 +12,30 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-// template <typename preUnaryOpType, typename posUnaryOpType>
-// using DeviceReducePtr = std::unique_ptr<DeviceReduce<preUnaryOpType, posUnaryOpType>>;
-
-template <int Rank, typename ReduceDims>
+// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
+// of reduce dims
+template <int Rank, int NumReduceDim>
 std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
 {
     static_assert(Rank <= 6, "bigger Rank size not supported!");

-    size_t tensor_total_length = 1;
+    size_t invariant_total_length = 1;
     size_t reduce_total_length = 1;

-    static_for<0, ReduceDims::Size(), 1>{}(
-        [&](auto i) { reduce_total_length *= inLengths[ReduceDims::At(i)]; });
+    constexpr int NumInvariantDim = Rank - NumReduceDim;

-    static_for<0, Rank, 1>{}([&](auto i) { tensor_total_length *= inLengths[i.value]; });
+    for(int i = NumInvariantDim; i < Rank; i++)
+        reduce_total_length *= inLengths[i];

-    return std::make_pair(tensor_total_length / reduce_total_length, reduce_total_length);
-};
+    for(int i = 0; i < NumInvariantDim; i++)
+        invariant_total_length *= inLengths[i];

-template <int x, typename Seq>
-constexpr bool belong()
-{
-    bool inside = false;
-
-    static_for<0, Seq::Size(), 1>{}([&](auto i) { inside = (inside || (x == Seq::At(i))); });
-
-    return (inside);
-};
+    return std::make_pair(invariant_total_length, reduce_total_length);
+};

-template <int Rank, typename ReduceDims, int start = 0>
-constexpr auto get_invariant_dims()
-{
-    static_assert(Rank <= 6, "bigger Rank size not supported!");
-
-    if constexpr(start >= Rank)
-        return Sequence<>{};
-    else
-    {
-        if constexpr(!belong<start, ReduceDims>())
-            return merge_sequences(Sequence<start>{},
-                                   get_invariant_dims<Rank, ReduceDims, start + 1>());
-        else
-            return get_invariant_dims<Rank, ReduceDims, start + 1>();
-    };
-};
-
 // helper functions using variadic template arguments
 template <index_t... Ns>
 static auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
 {
     return make_tuple(static_cast<index_t>(lengths[Ns])...);
 };
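
Because the caller now pre-shuffles dimensions into [invariant..., reduce...] order (see shuffle_tensor_dimensions in the next hunk), the rewritten get_2d_lengths reduces to a prefix/suffix product. A standalone copy with the template parameters replaced by runtime ints, checked against made-up lengths:

#include <cstdio>
#include <utility>
#include <vector>

// Standalone copy of the rewritten get_2d_lengths above, with the template
// parameters replaced by runtime ints for the demo. Lengths are assumed
// pre-shuffled to [invariant..., reduce...].
std::pair<size_t, size_t>
get_2d_lengths(int rank, int numReduceDim, const std::vector<int>& inLengths)
{
    size_t invariant_total_length = 1;
    size_t reduce_total_length    = 1;

    const int numInvariantDim = rank - numReduceDim;

    for(int i = numInvariantDim; i < rank; i++)
        reduce_total_length *= inLengths[i];

    for(int i = 0; i < numInvariantDim; i++)
        invariant_total_length *= inLengths[i];

    return std::make_pair(invariant_total_length, reduce_total_length);
}

int main()
{
    // shuffled lengths: invariant {64, 960}, reduce {16, 32}
    const auto p = get_2d_lengths(4, 2, {64, 960, 16, 32});
    std::printf("M = %zu, K = %zu\n", p.first, p.second); // M = 61440, K = 512
    return 0;
}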
...
@@ -74,6 +50,39 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
     return make_tuple_from_array_and_index_seq(lengths, index_seq);
 };

+template <index_t Rank, index_t NumReduceDim>
+std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
+                                           const std::vector<int>& reduceDims)
+{
+    std::vector<int> newLengthsStrides;
+
+    assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
+
+    int reduceFlag = 0;
+
+    // flag the bits for the reduceDims
+    for(int i = 0; i < NumReduceDim; i++)
+    {
+        reduceFlag |= 1 << reduceDims[i];
+    };
+
+    // collect invariant dimensions
+    for(int i = 0; i < Rank; i++)
+        if((reduceFlag & (1 << i)) == 0)
+        {
+            newLengthsStrides.push_back(origLengthsStrides[i]);
+        };
+
+    // collect reduce dimensions
+    for(int i = 0; i < Rank; i++)
+        if((reduceFlag & (1 << i)) > 0)
+        {
+            newLengthsStrides.push_back(origLengthsStrides[i]);
+        };
+
+    return newLengthsStrides;
+};
+
 } // namespace device
 } // namespace tensor_operation
...
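
A quick demonstration of the new shuffle helper's contract: invariant dimensions keep their relative order and come first, reduce dimensions follow. This standalone copy takes the rank at runtime for the demo, whereas CK's version takes Rank and NumReduceDim as template parameters:

#include <cassert>
#include <vector>

// Standalone copy of shuffle_tensor_dimensions above, taking the rank at
// runtime for the demo: invariant dims first (relative order preserved),
// reduce dims after.
std::vector<int> shuffle_tensor_dimensions(int rank,
                                           const std::vector<int>& origLengthsStrides,
                                           const std::vector<int>& reduceDims)
{
    std::vector<int> newLengthsStrides;

    int reduceFlag = 0;
    for(int d : reduceDims)
        reduceFlag |= 1 << d; // flag the bits for the reduceDims

    for(int i = 0; i < rank; i++)
        if((reduceFlag & (1 << i)) == 0)
            newLengthsStrides.push_back(origLengthsStrides[i]); // invariant dims

    for(int i = 0; i < rank; i++)
        if((reduceFlag & (1 << i)) != 0)
            newLengthsStrides.push_back(origLengthsStrides[i]); // reduce dims

    return newLengthsStrides;
}

int main()
{
    // lengths {8, 16, 32, 64}, reducing dims {1, 3} -> {8, 32, 16, 64}
    const auto shuffled = shuffle_tensor_dimensions(4, {8, 16, 32, 64}, {1, 3});
    assert((shuffled == std::vector<int>{8, 32, 16, 64}));
    return 0;
}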