Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
07a673c6
Commit
07a673c6
authored
Apr 14, 2022
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into cpu_avx2
parents
c0f698d5
ac0d8066
Changes
307
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
143 additions
and
227 deletions
+143
-227
include/ck/tensor_operation/gpu/device/convolution_utility.hpp
...de/ck/tensor_operation/gpu/device/convolution_utility.hpp
+0
-73
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
...on/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
+25
-25
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+4
-4
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+9
-8
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+4
-4
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+7
-7
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
...nv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
+7
-7
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+7
-7
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
...ation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
+6
-6
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
.../gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
+14
-26
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...on/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+3
-3
include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp
...nsor_operation/gpu/device/device_conv_backward_weight.hpp
+3
-3
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
+10
-10
include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
...ation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
+10
-10
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
..._operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
+25
-25
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
+5
-5
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp
...tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
...peration/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
.../gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
.../device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/device/convolution_utility.hpp
deleted
100644 → 0
View file @
c0f698d5
#ifndef CONVOLUTION_UTILITY_HPP
#define CONVOLUTION_UTILITY_HPP
#include <cassert>
#include <cstddef>
#include <vector>
namespace
ck
{
namespace
tensor_operation
{
struct
ConvolutionUtility
{
static
std
::
vector
<
ck
::
index_t
>
ComputeOutputSpatialLengths
(
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_strides
,
std
::
vector
<
ck
::
index_t
>
conv_dilations
,
std
::
vector
<
ck
::
index_t
>
in_left_pads
,
std
::
vector
<
ck
::
index_t
>
in_right_pads
)
{
if
(
input_spatial_lengths
.
size
()
==
2
)
{
assert
(
filter_spatial_lengths
.
size
()
==
2
);
assert
(
conv_strides
.
size
()
==
2
);
assert
(
conv_dilations
.
size
()
==
2
);
assert
(
in_left_pads
.
size
()
==
2
);
assert
(
in_right_pads
.
size
()
==
2
);
const
index_t
YEff
=
(
filter_spatial_lengths
[
0
]
-
1
)
*
conv_dilations
[
0
]
+
1
;
const
index_t
XEff
=
(
filter_spatial_lengths
[
1
]
-
1
)
*
conv_dilations
[
1
]
+
1
;
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
Ho
=
(
Hi
+
in_left_pads
[
0
]
+
in_right_pads
[
0
]
-
YEff
)
/
conv_strides
[
0
]
+
1
;
const
index_t
Wo
=
(
Wi
+
in_left_pads
[
1
]
+
in_right_pads
[
1
]
-
XEff
)
/
conv_strides
[
1
]
+
1
;
return
{
Ho
,
Wo
};
}
else
if
(
input_spatial_lengths
.
size
()
==
3
)
{
assert
(
filter_spatial_lengths
.
size
()
==
3
);
assert
(
conv_strides
.
size
()
==
3
);
assert
(
conv_dilations
.
size
()
==
3
);
assert
(
in_left_pads
.
size
()
==
3
);
assert
(
in_right_pads
.
size
()
==
3
);
const
index_t
ZEff
=
(
filter_spatial_lengths
[
0
]
-
1
)
*
conv_dilations
[
0
]
+
1
;
const
index_t
YEff
=
(
filter_spatial_lengths
[
1
]
-
1
)
*
conv_dilations
[
1
]
+
1
;
const
index_t
XEff
=
(
filter_spatial_lengths
[
2
]
-
1
)
*
conv_dilations
[
2
]
+
1
;
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
(
Di
+
in_left_pads
[
0
]
+
in_right_pads
[
0
]
-
ZEff
)
/
conv_strides
[
0
]
+
1
;
const
index_t
Ho
=
(
Hi
+
in_left_pads
[
1
]
+
in_right_pads
[
1
]
-
YEff
)
/
conv_strides
[
1
]
+
1
;
const
index_t
Wo
=
(
Wi
+
in_left_pads
[
2
]
+
in_right_pads
[
2
]
-
XEff
)
/
conv_strides
[
2
]
+
1
;
return
{
Do
,
Ho
,
Wo
};
}
else
{
return
{};
}
}
};
}
// namespace tensor_operation
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
07a673c6
...
@@ -105,7 +105,7 @@ template <typename ALayout,
...
@@ -105,7 +105,7 @@ template <typename ALayout,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
,
typename
D1ReduceOperation
,
GemmSpecialization
_t
GemmSpec
ialization
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -171,8 +171,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const
auto
MPad
=
M
-
MRaw
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MKPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad both M and K
// pad both M and K
assert
(
K
%
AK1
==
0
);
assert
(
K
%
AK1
==
0
);
...
@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -195,8 +195,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return
a_grid_desc_ak0_m_ak1
;
return
a_grid_desc_ak0_m_ak1
;
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
)
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
// pad M, but not K
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
assert
(
KRaw
%
AK1
==
0
);
...
@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -212,8 +212,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return
a_grid_desc_ak0_m_ak1
;
return
a_grid_desc_ak0_m_ak1
;
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
KPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
NKPadding
)
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
{
// pad K, but not M
// pad K, but not M
assert
(
K
%
AK1
==
0
);
assert
(
K
%
AK1
==
0
);
...
@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -274,8 +274,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const
auto
NPad
=
N
-
NRaw
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
NKPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad both N and K
// pad both N and K
assert
(
K
%
BK1
==
0
);
assert
(
K
%
BK1
==
0
);
...
@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -298,8 +298,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return
b_grid_desc_bk0_n_bk1
;
return
b_grid_desc_bk0_n_bk1
;
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
NPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
)
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
// pad N, but not K
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
assert
(
KRaw
%
BK1
==
0
);
...
@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -315,8 +315,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return
b_grid_desc_bk0_n_bk1
;
return
b_grid_desc_bk0_n_bk1
;
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
KPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MKPadding
)
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
{
// pad K, but not N
// pad K, but not N
assert
(
K
%
BK1
==
0
);
assert
(
K
%
BK1
==
0
);
...
@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -377,8 +377,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const
auto
MPad
=
M
-
MRaw
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad M and N
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
...
@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -387,8 +387,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MKPadding
)
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
{
// pad M, but not N
// pad M, but not N
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
...
@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -397,8 +397,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
NPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
NKPadding
)
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
{
// pad N, but not M
// pad N, but not M
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
...
@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -422,10 +422,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
MPad
=
M
-
MRaw
;
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad M
// pad M
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
...
@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
...
@@ -544,8 +544,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D0ReduceOperation
,
D1ReduceOperation
,
D1ReduceOperation
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
_t
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
07a673c6
...
@@ -129,7 +129,7 @@ struct DeviceBatchedGemmXdl
...
@@ -129,7 +129,7 @@ struct DeviceBatchedGemmXdl
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
M
));
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}
}();
}();
...
@@ -158,7 +158,7 @@ struct DeviceBatchedGemmXdl
...
@@ -158,7 +158,7 @@ struct DeviceBatchedGemmXdl
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
K
));
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}
}();
}();
...
@@ -183,7 +183,7 @@ struct DeviceBatchedGemmXdl
...
@@ -183,7 +183,7 @@ struct DeviceBatchedGemmXdl
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
M
));
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}
}();
}();
...
@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl
...
@@ -277,7 +277,7 @@ struct DeviceBatchedGemmXdl
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
07a673c6
...
@@ -52,10 +52,13 @@ template <typename InDataType,
...
@@ -52,10 +52,13 @@ template <typename InDataType,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
struct
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvWrw
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
:
public
DeviceConvBwdWeight
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
using
DeviceOp
=
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
using
ADataType
=
OutDataType
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
using
BDataType
=
InDataType
;
...
@@ -68,8 +71,6 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -68,8 +71,6 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
// TODO make A/B datatype different
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
using
ABDataType
=
InDataType
;
static
constexpr
index_t
NDimSpatial
=
2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -209,7 +210,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -209,7 +210,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -250,7 +251,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -250,7 +251,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -691,7 +692,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -691,7 +692,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceConv2d
WrW
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
str
<<
"DeviceConv2d
BwdWeight
Xdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
07a673c6
...
@@ -25,7 +25,7 @@ template <typename InDataType,
...
@@ -25,7 +25,7 @@ template <typename InDataType,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardDataSpecialization
_t
ConvBackwardDataSpecialization
,
ConvolutionBackwardDataSpecialization
ConvBackwardDataSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
...
@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -131,7 +131,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// A: output tensor
// A: output tensor
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
...
@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -368,7 +368,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -671,7 +671,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 pad = 0 conv
// check if it's 1x1, stride=1 pad = 0 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
07a673c6
...
@@ -27,7 +27,7 @@ template <
...
@@ -27,7 +27,7 @@ template <
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization
_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
...
@@ -125,7 +125,7 @@ struct
...
@@ -125,7 +125,7 @@ struct
const
auto
GemmMPad
=
GemmM
-
GemmMRaw
;
const
auto
GemmMPad
=
GemmM
-
GemmMRaw
;
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
// 1x1, stride=1, pad=0
{
// 1x1, stride=1, pad=0
const
index_t
GemmK
=
Y
*
X
*
C
;
const
index_t
GemmK
=
Y
*
X
*
C
;
assert
(
GemmK
%
GemmK1Number
==
0
);
assert
(
GemmK
%
GemmK1Number
==
0
);
...
@@ -179,7 +179,7 @@ struct
...
@@ -179,7 +179,7 @@ struct
resi_grid_desc_gemmm_gemmn
);
resi_grid_desc_gemmm_gemmn
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// 1x1, pad=0
{
// 1x1, pad=0
const
index_t
GemmK
=
Y
*
X
*
C
;
const
index_t
GemmK
=
Y
*
X
*
C
;
assert
(
GemmK
%
GemmK1Number
==
0
);
assert
(
GemmK
%
GemmK1Number
==
0
);
...
@@ -249,7 +249,7 @@ struct
...
@@ -249,7 +249,7 @@ struct
bias_grid_desc_gemmm_gemmn
,
bias_grid_desc_gemmm_gemmn
,
resi_grid_desc_gemmm_gemmn
);
resi_grid_desc_gemmm_gemmn
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
OddC
)
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
OddC
)
{
// C = odd value
{
// C = odd value
const
index_t
GemmKRaw
=
Y
*
X
*
C
;
const
index_t
GemmKRaw
=
Y
*
X
*
C
;
const
index_t
GemmK
=
math
::
integer_least_multiple
(
GemmKRaw
,
K0PerBlock
*
GemmK1Number
);
const
index_t
GemmK
=
math
::
integer_least_multiple
(
GemmKRaw
,
K0PerBlock
*
GemmK1Number
);
...
@@ -466,7 +466,7 @@ struct
...
@@ -466,7 +466,7 @@ struct
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -811,7 +811,7 @@ struct
...
@@ -811,7 +811,7 @@ struct
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
@@ -823,7 +823,7 @@ struct
...
@@ -823,7 +823,7 @@ struct
}
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
// check if it's 1x1 conv
// check if it's 1x1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
View file @
07a673c6
...
@@ -27,8 +27,8 @@ template <
...
@@ -27,8 +27,8 @@ template <
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
InMemoryDataOperationEnum
_t
OutGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
OutGlobalMemoryDataOperation
,
ConvolutionForwardSpecialization
_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
...
@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -124,7 +124,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
const
auto
GemmMPad
=
GemmM
-
GemmMRaw
;
const
auto
GemmMPad
=
GemmM
-
GemmMRaw
;
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
// 1x1, stride=1, pad=0
{
// 1x1, stride=1, pad=0
const
index_t
GemmK
=
Y
*
X
*
C
;
const
index_t
GemmK
=
Y
*
X
*
C
;
assert
(
GemmK
%
GemmK1Number
==
0
);
assert
(
GemmK
%
GemmK1Number
==
0
);
...
@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -174,7 +174,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
bias_grid_desc_gemmm_gemmn
);
bias_grid_desc_gemmm_gemmn
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// 1x1, pad=0
{
// 1x1, pad=0
const
index_t
GemmK
=
Y
*
X
*
C
;
const
index_t
GemmK
=
Y
*
X
*
C
;
assert
(
GemmK
%
GemmK1Number
==
0
);
assert
(
GemmK
%
GemmK1Number
==
0
);
...
@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -240,7 +240,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
bias_grid_desc_gemmm_gemmn
);
bias_grid_desc_gemmm_gemmn
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
OddC
)
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
OddC
)
{
// C = odd value
{
// C = odd value
const
index_t
GemmKRaw
=
Y
*
X
*
C
;
const
index_t
GemmKRaw
=
Y
*
X
*
C
;
const
index_t
GemmK
=
math
::
integer_least_multiple
(
GemmKRaw
,
K0PerBlock
*
GemmK1Number
);
const
index_t
GemmK
=
math
::
integer_least_multiple
(
GemmKRaw
,
K0PerBlock
*
GemmK1Number
);
...
@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -763,7 +763,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
...
@@ -775,7 +775,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
}
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
// check if it's 1x1 conv
// check if it's 1x1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
07a673c6
...
@@ -26,7 +26,7 @@ template <
...
@@ -26,7 +26,7 @@ template <
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization
_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
...
@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -120,7 +120,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const
auto
GemmMPad
=
GemmM
-
GemmMRaw
;
const
auto
GemmMPad
=
GemmM
-
GemmMRaw
;
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
// 1x1, stride=1, pad=0
{
// 1x1, stride=1, pad=0
const
index_t
GemmK
=
Y
*
X
*
C
;
const
index_t
GemmK
=
Y
*
X
*
C
;
assert
(
GemmK
%
GemmK1Number
==
0
);
assert
(
GemmK
%
GemmK1Number
==
0
);
...
@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -165,7 +165,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
out_gemmm_gemmn_grid_desc
);
out_gemmm_gemmn_grid_desc
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// 1x1, pad=0
{
// 1x1, pad=0
const
index_t
GemmK
=
Y
*
X
*
C
;
const
index_t
GemmK
=
Y
*
X
*
C
;
assert
(
GemmK
%
GemmK1Number
==
0
);
assert
(
GemmK
%
GemmK1Number
==
0
);
...
@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -226,7 +226,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
);
out_gemmm_gemmn_grid_desc
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
OddC
)
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
OddC
)
{
// C = odd value
{
// C = odd value
const
index_t
GemmKRaw
=
Y
*
X
*
C
;
const
index_t
GemmKRaw
=
Y
*
X
*
C
;
const
index_t
GemmK
=
math
::
integer_least_multiple
(
GemmKRaw
,
K0PerBlock
*
GemmK1Number
);
const
index_t
GemmK
=
math
::
integer_least_multiple
(
GemmKRaw
,
K0PerBlock
*
GemmK1Number
);
...
@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -424,7 +424,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
AccDataType
,
AccDataType
,
CDataType
,
// TODO: Add ShuffleType for DeviceConv2d
CDataType
,
// TODO: Add ShuffleType for DeviceConv2d
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -733,7 +733,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
...
@@ -745,7 +745,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
}
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
// check if it's 1x1 conv
// check if it's 1x1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
07a673c6
...
@@ -25,7 +25,7 @@ template <typename InDataType,
...
@@ -25,7 +25,7 @@ template <typename InDataType,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization
_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
...
@@ -119,7 +119,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -119,7 +119,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// A: input tensor
// A: input tensor
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmk_grid_desc
=
...
@@ -159,7 +159,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -159,7 +159,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
out_gemmm_gemmn_grid_desc
);
out_gemmm_gemmn_grid_desc
);
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
// A: input tensor
// A: input tensor
const
auto
in_n_hi_wi_c_grid_desc
=
const
auto
in_n_hi_wi_c_grid_desc
=
...
@@ -316,7 +316,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -316,7 +316,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -565,7 +565,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -565,7 +565,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
@@ -577,7 +577,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -577,7 +577,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
// check if it's 1x1 conv
// check if it's 1x1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
...
...
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp
View file @
07a673c6
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include <iostream>
#include <iostream>
#include <memory>
#include <memory>
#include <sstream>
#include <sstream>
#include "conv
olution
_util
ity
.hpp"
#include "conv
_fwd
_util.hpp"
#include "device.hpp"
#include "device.hpp"
#include "device_conv_fwd.hpp"
#include "device_conv_fwd.hpp"
#include "common_header.hpp"
#include "common_header.hpp"
...
@@ -53,36 +53,30 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
...
@@ -53,36 +53,30 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
OutElementwiseOperation
out_element_op
)
:
N_
{
N
},
:
params_
{
3
,
K_
{
K
},
N
,
C_
{
C
},
K
,
in_spatial_lengths_
{
input_spatial_lengths
},
C
,
filter_spatial_lengths_
{
filter_spatial_lengths
},
filter_spatial_lengths
,
input_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
out_spatial_lengths_
{
output_spatial_lengths
},
out_spatial_lengths_
{
output_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
in_left_pads_
{
input_left_pads
},
in_right_pads_
{
input_right_pads
},
p_in_
{
p_in
},
p_in_
{
p_in
},
p_wei_
{
p_wei
},
p_wei_
{
p_wei
},
p_out_
{
p_out
},
p_out_
{
p_out
},
in_element_op_
{
in_element_op
},
in_element_op_
{
in_element_op
},
wei_element_op_
{
wei_element_op
},
wei_element_op_
{
wei_element_op
},
out_element_op_
{
out_element_op
}
out_element_op_
{
out_element_op
}
{
{
}
}
// private:
// private:
index_t
N_
;
utils
::
conv
::
ConvParams
params_
;
index_t
K_
;
index_t
C_
;
std
::
vector
<
index_t
>
in_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
out_spatial_lengths_
;
std
::
vector
<
index_t
>
out_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
conv_filter_dilations_
;
std
::
vector
<
index_t
>
in_left_pads_
;
std
::
vector
<
index_t
>
in_right_pads_
;
const
InDataType
*
p_in_
;
const
InDataType
*
p_in_
;
const
WeiDataType
*
p_wei_
;
const
WeiDataType
*
p_wei_
;
...
@@ -157,13 +151,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
...
@@ -157,13 +151,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
std
::
vector
<
index_t
>
out_spatial_lengths
=
std
::
vector
<
index_t
>
out_spatial_lengths
=
arg
.
params_
.
GetOutputSpatialLengths
();
ConvolutionUtility
::
ComputeOutputSpatialLengths
(
arg
.
in_spatial_lengths_
,
arg
.
filter_spatial_lengths_
,
arg
.
conv_filter_strides_
,
arg
.
conv_filter_dilations_
,
arg
.
in_left_pads_
,
arg
.
in_right_pads_
);
bool
out_lengths_are_consistent
=
out_spatial_lengths
[
0
]
==
arg
.
out_spatial_lengths_
[
0
]
&&
bool
out_lengths_are_consistent
=
out_spatial_lengths
[
0
]
==
arg
.
out_spatial_lengths_
[
0
]
&&
out_spatial_lengths
[
1
]
==
arg
.
out_spatial_lengths_
[
1
]
&&
out_spatial_lengths
[
1
]
==
arg
.
out_spatial_lengths_
[
1
]
&&
...
...
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
07a673c6
...
@@ -83,7 +83,7 @@ template <typename InDataType,
...
@@ -83,7 +83,7 @@ template <typename InDataType,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization
_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
...
@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
...
@@ -207,7 +207,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
2
];
const
index_t
Wo
=
output_spatial_lengths
[
2
];
static_assert
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Default
,
static_assert
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Default
,
"Wrong! This specialization not implemented!"
);
"Wrong! This specialization not implemented!"
);
const
auto
in_desc_n_di_hi_wi_c
=
const
auto
in_desc_n_di_hi_wi_c
=
...
@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
...
@@ -287,7 +287,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
InDataType
,
InDataType
,
AccDataType
,
AccDataType
,
OutDataType
,
OutDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
include/ck/tensor_operation/gpu/device/device_conv_backward_weight.hpp
View file @
07a673c6
...
@@ -11,7 +11,7 @@ namespace device {
...
@@ -11,7 +11,7 @@ namespace device {
template
<
typename
InElementwiseOperation
,
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
>
struct
DeviceConv
Wrw
:
public
BaseOperator
struct
DeviceConv
BwdWeight
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in
,
MakeArgumentPointer
(
const
void
*
p_in
,
...
@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator
...
@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator
template
<
typename
InElementwiseOperation
,
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
>
using
DeviceConv
Wrw
Ptr
=
std
::
unique_ptr
<
using
DeviceConv
BwdWeight
Ptr
=
std
::
unique_ptr
<
DeviceConv
Wrw
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
DeviceConv
BwdWeight
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
07a673c6
...
@@ -25,7 +25,7 @@ template <typename InDataType,
...
@@ -25,7 +25,7 @@ template <typename InDataType,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardDataSpecialization
_t
ConvBackwardDataSpecialization
,
ConvolutionBackwardDataSpecialization
ConvBackwardDataSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
...
@@ -116,7 +116,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
...
@@ -116,7 +116,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// A: output tensor
// A: output tensor
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
...
@@ -336,7 +336,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
...
@@ -336,7 +336,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// A: output tensor
// A: output tensor
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
...
@@ -618,7 +618,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
...
@@ -618,7 +618,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// A: output tensor
// A: output tensor
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
...
@@ -917,21 +917,21 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
...
@@ -917,21 +917,21 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
}
// function end
}
// function end
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
static
auto
GetABCGridDesc
()
{
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
1
>
(
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
1
>
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
0
});
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
0
});
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
static
auto
GetABCGridDesc
()
{
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
2
>
(
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
2
>
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
});
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
});
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
static
auto
GetABCGridDesc
()
{
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
3
>
(
1
,
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
3
>
(
1
,
...
@@ -959,7 +959,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
...
@@ -959,7 +959,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -1385,7 +1385,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
...
@@ -1385,7 +1385,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 pad = 0 conv
// check if it's 1x1, stride=1 pad = 0 conv
for
(
int
i
=
0
;
i
<
NumDimSpatial
;
i
++
)
for
(
int
i
=
0
;
i
<
NumDimSpatial
;
i
++
)
...
@@ -1527,7 +1527,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
...
@@ -1527,7 +1527,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
<<
K0PerBlock
<<
K0PerBlock
<<
">"
;
<<
">"
;
if
constexpr
(
ConvBackwardDataSpecialization
==
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
_t
::
Filter1x1Stride1Pad0
){
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
){
str
<<
" Filter1x1Stride1Pad0"
;
str
<<
" Filter1x1Stride1Pad0"
;
}
}
...
...
include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
07a673c6
...
@@ -44,7 +44,7 @@ template <typename InDataType,
...
@@ -44,7 +44,7 @@ template <typename InDataType,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionForwardSpecialization
_t
ConvForwardSpecialization
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
...
@@ -142,7 +142,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -142,7 +142,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmk_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
...
@@ -156,7 +156,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -156,7 +156,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
const
auto
in_n_wi_c_grid_desc
=
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
...
@@ -262,7 +262,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -262,7 +262,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmk_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
...
@@ -276,7 +276,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -276,7 +276,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
const
auto
in_n_hi_wi_c_grid_desc
=
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
...
@@ -395,7 +395,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -395,7 +395,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmmraw_gemmk_grid_desc
=
const
auto
in_gemmmraw_gemmk_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
...
@@ -409,7 +409,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -409,7 +409,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
const
auto
in_n_di_hi_wi_c_grid_desc
=
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
...
@@ -613,7 +613,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -613,7 +613,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
@@ -878,7 +878,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -878,7 +878,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
for
(
ck
::
index_t
i
=
0
;
i
<
NumDimSpatial
;
++
i
)
for
(
ck
::
index_t
i
=
0
;
i
<
NumDimSpatial
;
++
i
)
...
@@ -891,7 +891,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -891,7 +891,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
}
}
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
_t
::
Filter1x1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
{
// check if it's 1x1 conv
// check if it's 1x1 conv
for
(
ck
::
index_t
i
=
0
;
i
<
NumDimSpatial
;
++
i
)
for
(
ck
::
index_t
i
=
0
;
i
<
NumDimSpatial
;
++
i
)
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
View file @
07a673c6
...
@@ -29,7 +29,7 @@ template <typename ALayout,
...
@@ -29,7 +29,7 @@ template <typename ALayout,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D0ReduceOperation
,
typename
D0ReduceOperation
,
typename
D1ReduceOperation
,
typename
D1ReduceOperation
,
GemmSpecialization
_t
GemmSpec
ialization
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -95,8 +95,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -95,8 +95,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const
auto
MPad
=
M
-
MRaw
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MKPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad both M and K
// pad both M and K
assert
(
K
%
AK1
==
0
);
assert
(
K
%
AK1
==
0
);
...
@@ -119,8 +119,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -119,8 +119,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return
a_grid_desc_ak0_m_ak1
;
return
a_grid_desc_ak0_m_ak1
;
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
)
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
// pad M, but not K
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
assert
(
KRaw
%
AK1
==
0
);
...
@@ -136,8 +136,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -136,8 +136,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return
a_grid_desc_ak0_m_ak1
;
return
a_grid_desc_ak0_m_ak1
;
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
KPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
NKPadding
)
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
{
// pad K, but not M
// pad K, but not M
assert
(
K
%
AK1
==
0
);
assert
(
K
%
AK1
==
0
);
...
@@ -198,8 +198,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -198,8 +198,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const
auto
NPad
=
N
-
NRaw
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
NKPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad both N and K
// pad both N and K
assert
(
K
%
BK1
==
0
);
assert
(
K
%
BK1
==
0
);
...
@@ -222,8 +222,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -222,8 +222,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return
b_grid_desc_bk0_n_bk1
;
return
b_grid_desc_bk0_n_bk1
;
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
NPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
)
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
// pad N, but not K
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
assert
(
KRaw
%
BK1
==
0
);
...
@@ -239,8 +239,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -239,8 +239,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return
b_grid_desc_bk0_n_bk1
;
return
b_grid_desc_bk0_n_bk1
;
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
KPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MKPadding
)
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
{
// pad K, but not N
// pad K, but not N
assert
(
K
%
BK1
==
0
);
assert
(
K
%
BK1
==
0
);
...
@@ -301,8 +301,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -301,8 +301,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const
auto
MPad
=
M
-
MRaw
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad M and N
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
...
@@ -311,8 +311,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -311,8 +311,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MKPadding
)
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
{
// pad M, but not N
// pad M, but not N
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
...
@@ -321,8 +321,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -321,8 +321,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
else
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
NPadding
||
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
NKPadding
)
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
{
// pad N, but not M
// pad N, but not M
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
...
@@ -346,10 +346,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -346,10 +346,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
MPad
=
M
-
MRaw
;
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
// pad M
// pad M
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
...
@@ -382,8 +382,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
...
@@ -382,8 +382,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CElementwiseOperation
,
CElementwiseOperation
,
D0ReduceOperation
,
D0ReduceOperation
,
D1ReduceOperation
,
D1ReduceOperation
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
_t
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
View file @
07a673c6
...
@@ -27,7 +27,7 @@ template <typename ADataType,
...
@@ -27,7 +27,7 @@ template <typename ADataType,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
_t
GemmSpec
ialization
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
...
@@ -80,7 +80,7 @@ struct DeviceGemmXdl
...
@@ -80,7 +80,7 @@ struct DeviceGemmXdl
}
}
}();
}();
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
)
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
...
@@ -119,7 +119,7 @@ struct DeviceGemmXdl
...
@@ -119,7 +119,7 @@ struct DeviceGemmXdl
}
}
}();
}();
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
)
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
...
@@ -154,7 +154,7 @@ struct DeviceGemmXdl
...
@@ -154,7 +154,7 @@ struct DeviceGemmXdl
}
}
}();
}();
if
constexpr
(
GemmSpec
ialization
==
GemmSpecialization
_t
::
MNPadding
)
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
...
@@ -186,7 +186,7 @@ struct DeviceGemmXdl
...
@@ -186,7 +186,7 @@ struct DeviceGemmXdl
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp
View file @
07a673c6
...
@@ -138,7 +138,7 @@ struct DeviceGemmXdl_C_Shuffle
...
@@ -138,7 +138,7 @@ struct DeviceGemmXdl_C_Shuffle
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
View file @
07a673c6
...
@@ -139,7 +139,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
...
@@ -139,7 +139,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
View file @
07a673c6
...
@@ -147,7 +147,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
...
@@ -147,7 +147,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
View file @
07a673c6
...
@@ -169,7 +169,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
...
@@ -169,7 +169,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
...
...
Prev
1
2
3
4
5
6
7
8
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment