gaoqiong / composable_kernel · Commits · 4698993d

Unverified commit 4698993d, authored Nov 15, 2022 by Po Yen Chen, committed by GitHub on Nov 15, 2022.

    Merge branch 'develop' into wmma_op

Parents: ab663329, 7038723a
Changes: 202 files in total; this page shows 20 changed files with 737 additions and 506 deletions (+737 −506).
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp  +229 −90
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp  +7 −8
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  +14 −3
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp  +1 −1
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp  +30 −44
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp  +8 −3
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp  +2 −2
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp  +4 −3
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  +1 −1
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp  +0 −230
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp  +44 −36
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp  +235 −0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp  +2 −2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp  +1 −3
library/include/ck/library/utility/algorithm.hpp  +43 −0
library/include/ck/library/utility/check_err.hpp  +46 −43
library/include/ck/library/utility/convolution_parameter.hpp  +6 −8
library/include/ck/library/utility/fill.hpp  +14 −3
library/include/ck/library/utility/host_tensor.hpp  +28 −26
library/include/ck/library/utility/iterator.hpp  +22 −0
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp
→ include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp

@@ -4,13 +4,14 @@
 #pragma once

 #include <iostream>
 #include <numeric>
 #include <sstream>

 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
 #include "ck/host_utility/device_prop.hpp"

@@ -20,6 +21,108 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

+namespace {
+
+struct ComputePtrOffsetOfStridedBatch
+{
+    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideA_);
+    }
+
+    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideB_);
+    }
+
+    __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideC_);
+    }
+
+    index_t BatchStrideA_;
+    index_t BatchStrideB_;
+    index_t BatchStrideC_;
+};
+
+} // namespace
+
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename AGridDesc_B_K0_M_K1,
+          typename BGridDesc_B_K0_N_K1,
+          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename Block2CTileMap,
+          typename ComputePtrOffsetOfBatch,
+          bool HasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_batched_gemm_xdlops_bwd_weight(
+            const FloatAB* __restrict__ p_a_grid,
+            const FloatAB* __restrict__ p_b_grid,
+            FloatC* __restrict__ p_c_grid,
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CElementwiseOperation c_element_op,
+            const index_t batch_count,
+            const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
+            const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
+            const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+                c_grid_desc_mblock_mperblock_nblock_nperblock,
+            const Block2CTileMap block_2_ctile_map,
+            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+    const index_t num_blocks_per_batch =
+        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+
+    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+
+    __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
+
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                  p_b_grid + b_batch_offset,
+                                                  p_c_grid + c_batch_offset,
+                                                  p_shared,
+                                                  a_b_k0_m_k1_grid_desc,
+                                                  b_b_k0_n_k1_grid_desc,
+                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                  a_element_op,
+                                                  b_element_op,
+                                                  c_element_op,
+                                                  block_2_ctile_map);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = a_b_k0_m_k1_grid_desc;
+    ignore = b_b_k0_n_k1_grid_desc;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = batch_count;
+    ignore = block_2_ctile_map;
+    ignore = compute_ptr_offset_of_batch;
+
+    compute_ptr_offset_of_batch.GetAPtrOffset(0);
+    compute_ptr_offset_of_batch.GetBPtrOffset(0);
+    compute_ptr_offset_of_batch.GetCPtrOffset(0);
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
 // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
 template <ck::index_t NDimSpatial,
           typename InDataType,
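Note: the added kernel launches batch_count times as many blocks as a single grouped GEMM needs and derives the group index from the flat block id. A minimal standalone sketch of that index arithmetic (the constants here are illustrative, not from the commit):

#include <cstdint>
#include <cstdio>

int main()
{
    // Blocks are partitioned into batch_count contiguous chunks; each chunk
    // advances the A pointer by the per-group batch stride (GetAPtrOffset).
    const std::int64_t grid_size = 12, batch_count = 3, batch_stride_a = 1024;
    const std::int64_t num_blocks_per_batch = grid_size / batch_count;
    for(std::int64_t block_id = 0; block_id < grid_size; ++block_id)
    {
        const std::int64_t g_idx          = block_id / num_blocks_per_batch;
        const std::int64_t a_batch_offset = g_idx * batch_stride_a;
        std::printf("block %2lld -> group %lld, A offset %lld\n",
                    (long long)block_id, (long long)g_idx, (long long)a_batch_offset);
    }
}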
@@ -57,21 +160,21 @@ template <ck::index_t NDimSpatial,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
-struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
-    : public DeviceConvBwdWeight<
+struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
+    : public DeviceGroupedConvBwdWeight<
           NDimSpatial,
           ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::NWC,
-                                        ck::tensor_layout::convolution::NHWC,
-                                        ck::tensor_layout::convolution::NDHWC>>,
+                              ck::Tuple<ck::tensor_layout::convolution::GNWC,
+                                        ck::tensor_layout::convolution::GNHWC,
+                                        ck::tensor_layout::convolution::GNDHWC>>,
           ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::KXC,
-                                        ck::tensor_layout::convolution::KYXC,
-                                        ck::tensor_layout::convolution::KZYXC>>,
+                              ck::Tuple<ck::tensor_layout::convolution::GKXC,
+                                        ck::tensor_layout::convolution::GKYXC,
+                                        ck::tensor_layout::convolution::GKZYXC>>,
           ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::NWK,
-                                        ck::tensor_layout::convolution::NHWK,
-                                        ck::tensor_layout::convolution::NDHWK>>,
+                              ck::Tuple<ck::tensor_layout::convolution::GNWK,
+                                        ck::tensor_layout::convolution::GNHWK,
+                                        ck::tensor_layout::convolution::GNDHWK>>,
           InDataType,
           WeiDataType,
           OutDataType,

@@ -79,7 +182,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
           WeiElementwiseOperation,
           OutElementwiseOperation>
 {
-    using DeviceOp = DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle;
+    using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle;

     using ADataType = OutDataType;
     using BDataType = InDataType;

@@ -117,18 +220,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
     static constexpr auto BBlockLdsN1Padding = 4;

     template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
-    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N, ck::index_t K, ck::index_t C,
-        std::vector<ck::index_t> input_spatial_lengths,
-        std::vector<ck::index_t> filter_spatial_lengths,
-        std::vector<ck::index_t> output_spatial_lengths,
-        std::vector<ck::index_t> conv_filter_strides,
-        std::vector<ck::index_t> conv_filter_dilations,
-        std::vector<ck::index_t> input_left_pads,
-        std::vector<ck::index_t> input_right_pads,
-        ck::index_t batch_k)
+    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+        ck::index_t N, ck::index_t K, ck::index_t C,
+        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
+        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
+        std::array<ck::index_t, NDimSpatial> input_left_pads,
+        std::array<ck::index_t, NDimSpatial> input_right_pads,
+        ck::index_t batch_k)
     {
         using namespace ck;

@@ -269,18 +372,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
     }

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
-    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N, ck::index_t K, ck::index_t C,
-        std::vector<ck::index_t> input_spatial_lengths,
-        std::vector<ck::index_t> filter_spatial_lengths,
-        std::vector<ck::index_t> output_spatial_lengths,
-        std::vector<ck::index_t> conv_filter_strides,
-        std::vector<ck::index_t> conv_filter_dilations,
-        std::vector<ck::index_t> input_left_pads,
-        std::vector<ck::index_t> input_right_pads,
-        ck::index_t batch_k)
+    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+        ck::index_t N, ck::index_t K, ck::index_t C,
+        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
+        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
+        std::array<ck::index_t, NDimSpatial> input_left_pads,
+        std::array<ck::index_t, NDimSpatial> input_right_pads,
+        ck::index_t batch_k)
     {
         using namespace ck;

@@ -436,18 +539,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
     }

     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
-    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N, ck::index_t K, ck::index_t C,
-        std::vector<ck::index_t> input_spatial_lengths,
-        std::vector<ck::index_t> filter_spatial_lengths,
-        std::vector<ck::index_t> output_spatial_lengths,
-        std::vector<ck::index_t> conv_filter_strides,
-        std::vector<ck::index_t> conv_filter_dilations,
-        std::vector<ck::index_t> input_left_pads,
-        std::vector<ck::index_t> input_right_pads,
-        ck::index_t batch_k)
+    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
+        ck::index_t N, ck::index_t K, ck::index_t C,
+        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
+        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
+        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
+        std::array<ck::index_t, NDimSpatial> input_left_pads,
+        std::array<ck::index_t, NDimSpatial> input_right_pads,
+        ck::index_t batch_k)
     {
         using namespace ck;
@@ -664,8 +767,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
     }

     template <index_t Dim>
-    static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
-                                  const std::vector<index_t>& stride,
+    static auto MakeDescriptor_M0(const std::array<index_t, Dim>& shape,
+                                  const std::array<index_t, Dim>& stride,
                                   index_t gridSize,
                                   index_t blockSize)
     {

@@ -759,16 +862,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
         Argument(const InDataType* p_in_grid,
                  WeiDataType* p_wei_grid,
                  const OutDataType* p_out_grid,
+                 ck::index_t G,
                  ck::index_t N,
                  ck::index_t K,
                  ck::index_t C,
-                 std::vector<ck::index_t> input_spatial_lengths,
-                 std::vector<ck::index_t> filter_spatial_lengths,
-                 std::vector<ck::index_t> output_spatial_lengths,
-                 std::vector<ck::index_t> conv_filter_strides,
-                 std::vector<ck::index_t> conv_filter_dilations,
-                 std::vector<ck::index_t> input_left_pads,
-                 std::vector<ck::index_t> input_right_pads,
+                 std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
+                 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
+                 std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
+                 std::array<ck::index_t, NDimSpatial> conv_filter_strides,
+                 std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
+                 std::array<ck::index_t, NDimSpatial> input_left_pads,
+                 std::array<ck::index_t, NDimSpatial> input_right_pads,
                  ck::index_t M01,
                  ck::index_t N01,
                  InElementwiseOperation in_element_op,

@@ -783,11 +887,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
               c_grid_desc_m_n_{},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               block_2_ctile_map_{},
+              compute_ptr_offset_of_batch_{},
               M01_{M01},
               N01_{N01},
               a_element_op_{out_element_op},
               b_element_op_{in_element_op},
               c_element_op_{wei_element_op},
+              Conv_G_{G},
               Conv_N_{N},
               Conv_K_{K},
               Conv_C_{C},
@@ -819,6 +925,26 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
             block_2_ctile_map_ =
                 GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);

+            // A/B/C Batch Stride
+            compute_ptr_offset_of_batch_.BatchStrideA_ =
+                N * K *
+                std::accumulate(begin(output_spatial_lengths),
+                                end(output_spatial_lengths),
+                                index_t{1},
+                                std::multiplies<>{});
+            compute_ptr_offset_of_batch_.BatchStrideB_ =
+                N * C *
+                std::accumulate(begin(input_spatial_lengths),
+                                end(input_spatial_lengths),
+                                index_t{1},
+                                std::multiplies<>{});
+            compute_ptr_offset_of_batch_.BatchStrideC_ =
+                K * C *
+                std::accumulate(begin(filter_spatial_lengths),
+                                end(filter_spatial_lengths),
+                                index_t{1},
+                                std::multiplies<>{});
+
             if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
                                            b_grid_desc_kbatch_k0_n_k1_,
                                            c_grid_desc_m_n_,
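Note: each batch stride above is simply the element count of one group's slice of a packed tensor (output for A, input for B, weight for C). A standalone sketch of the same arithmetic with illustrative sizes (nothing here is from the commit):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

static std::int64_t product(const std::vector<std::int64_t>& v)
{
    return std::accumulate(v.begin(), v.end(), std::int64_t{1}, std::multiplies<>{});
}

int main()
{
    const std::int64_t N = 4, K = 64, C = 32;
    const std::vector<std::int64_t> out_spatial{28, 28}, in_spatial{30, 30}, filter{3, 3};

    const std::int64_t batch_stride_a = N * K * product(out_spatial); // out[g] elements
    const std::int64_t batch_stride_b = N * C * product(in_spatial);  // in[g] elements
    const std::int64_t batch_stride_c = K * C * product(filter);      // wei[g] elements
    return (batch_stride_a > 0 && batch_stride_b > 0 && batch_stride_c > 0) ? 0 : 1;
}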
@@ -836,21 +962,29 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
         BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;

         Block2CTileMap block_2_ctile_map_;
+
+        // for computing batch offset
+        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;

         index_t M01_;
         index_t N01_;

         InElementwiseOperation a_element_op_;
         OutElementwiseOperation b_element_op_;
         WeiElementwiseOperation c_element_op_;

         // for checking IsSupportedArgument()
+        index_t Conv_G_;
         index_t Conv_N_;
         index_t Conv_K_;
         index_t Conv_C_;
-        std::vector<index_t> output_spatial_lengths_;
-        std::vector<index_t> filter_spatial_lengths_;
-        std::vector<index_t> conv_filter_strides_;
-        std::vector<index_t> input_left_pads_;
-        std::vector<index_t> input_right_pads_;
+        std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
+        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
+        std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
+        std::array<ck::index_t, NDimSpatial> input_left_pads_;
+        std::array<ck::index_t, NDimSpatial> input_right_pads_;
         index_t k_batch_;
     };

@@ -873,14 +1007,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
                       << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
                       << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;

-            std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+            std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
+                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
         }

         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             ShowInfo(arg);

             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                             arg.b_grid_desc_kbatch_k0_n_k1_,
                                             arg.c_grid_desc_m_n_,

@@ -891,7 +1023,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
             }

             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;

             const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
@@ -900,17 +1032,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
             auto launch_kernel = [&](auto has_main_k_block_loop) {
                 constexpr bool has_main_loop = has_main_k_block_loop.value;

-                const auto kernel = kernel_gemm_xdlops_bwd_weight<
+                const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                     OutElementwiseOperation,
                     InElementwiseOperation,
                     WeiElementwiseOperation,
+                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                    remove_reference_t<DeviceOp::Block2CTileMap>,
+                    ComputePtrOffsetOfStridedBatch,
                     has_main_loop>;

                 return launch_and_time_kernel(stream_config,

@@ -921,13 +1054,15 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
                                               arg.p_a_grid_,
                                               arg.p_b_grid_,
                                               arg.p_c_grid_,
-                                              arg.a_grid_desc_kbatch_k0_m_k1_,
-                                              arg.b_grid_desc_kbatch_k0_n_k1_,
-                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                               arg.a_element_op_,
                                               arg.b_element_op_,
                                               arg.c_element_op_,
-                                              arg.block_2_ctile_map_);
+                                              arg.Conv_G_,
+                                              arg.a_grid_desc_kbatch_k0_m_k1_,
+                                              arg.b_grid_desc_kbatch_k0_n_k1_,
+                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.block_2_ctile_map_,
+                                              arg.compute_ptr_offset_of_batch_);
             };

             if(has_main_k0_block_loop)

@@ -998,16 +1133,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
     static auto MakeArgument(const InDataType* p_in_grid,
                              WeiDataType* p_wei_grid,
                              const OutDataType* p_out_grid,
+                             ck::index_t G,
                              ck::index_t N,
                              ck::index_t K,
                              ck::index_t C,
-                             std::vector<ck::index_t> input_spatial_lengths,
-                             std::vector<ck::index_t> filter_spatial_lengths,
-                             std::vector<ck::index_t> output_spatial_lengths,
-                             std::vector<ck::index_t> conv_filter_strides,
-                             std::vector<ck::index_t> conv_filter_dilations,
-                             std::vector<ck::index_t> input_left_pads,
-                             std::vector<ck::index_t> input_right_pads,
+                             std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
+                             std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
+                             std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
+                             std::array<ck::index_t, NDimSpatial> conv_filter_strides,
+                             std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
+                             std::array<ck::index_t, NDimSpatial> input_left_pads,
+                             std::array<ck::index_t, NDimSpatial> input_right_pads,
                              InElementwiseOperation in_element_op,
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op,

@@ -1016,6 +1152,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
         return Argument{p_in_grid,
                         p_wei_grid,
                         p_out_grid,
+                        G,
                         N,
                         K,
                         C,

@@ -1040,16 +1177,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
     MakeArgumentPointer(const void* p_in_grid,
                         void* p_wei_grid,
                         const void* p_out_grid,
+                        ck::index_t G,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t C,
-                        std::vector<ck::index_t> input_spatial_lengths,
-                        std::vector<ck::index_t> filter_spatial_lengths,
-                        std::vector<ck::index_t> output_spatial_lengths,
-                        std::vector<ck::index_t> conv_filter_strides,
-                        std::vector<ck::index_t> conv_filter_dilations,
-                        std::vector<ck::index_t> input_left_pads,
-                        std::vector<ck::index_t> input_right_pads,
+                        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
+                        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
+                        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
+                        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
+                        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
+                        std::array<ck::index_t, NDimSpatial> input_left_pads,
+                        std::array<ck::index_t, NDimSpatial> input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op,

@@ -1058,6 +1196,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
         return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
                                           static_cast<WeiDataType*>(p_wei_grid),
                                           static_cast<const OutDataType*>(p_out_grid),
+                                          G,
                                           N,
                                           K,
                                           C,

@@ -1086,7 +1225,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle"
+        str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp

@@ -22,6 +22,7 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/host_utility/io.hpp"
+#include "ck/library/utility/numeric.hpp"

 namespace ck {
 namespace tensor_operation {

@@ -410,10 +411,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
     {
         const index_t N = r_g_n_wos_lengths[1];

-        const index_t NHoWo = N * std::accumulate(r_g_n_wos_lengths.begin() + 2,
-                                                  r_g_n_wos_lengths.begin() + 2 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>());

         const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(NHoWo));

@@ -435,10 +435,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
         const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 2];

-        const index_t NHoWo = N * std::accumulate(r_g_n_wos_lengths.begin() + 2,
-                                                  r_g_n_wos_lengths.begin() + 2 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>());

         const auto r_grid_desc_mraw =
             make_naive_tensor_descriptor(make_tuple(NHoWo), make_tuple(WoStride));
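Note: ck::accumulate_n replaces the std::accumulate calls that needed an explicit end iterator. Its definition presumably lives in the newly included ck/library/utility/numeric.hpp (not shown in this diff); a minimal sketch of the semantics implied by the call sites:

#include <functional>
#include <iterator>
#include <numeric>

namespace sketch {
// Presumed behavior: fold op over [first, first + count), starting from init.
template <typename T, typename ForwardIterator, typename Size, typename BinaryOperation>
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> T
{
    return std::accumulate(first, std::next(first, count), init, op);
}
} // namespace sketch

int main()
{
    const int lengths[] = {8, 28, 28}; // e.g. {N, Ho, Wo}
    // Product of the NDimSpatial = 2 spatial extents following position 0:
    const int howo = sketch::accumulate_n<int>(lengths + 1, 2, 1, std::multiplies<>());
    return howo == 784 ? 0 : 1;
}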
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -364,14 +364,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
                                       index_t M01 = 1,
                                       index_t N01 = 1,
                                       index_t KSplit = 1)
-        : M01_(M01),
+        : c_grid_desc_m_n_(c_grid_desc_m_n),
+          M01_(M01),
           N01_(N01),
           KSplit_(KSplit),
           underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
     {
     }

-    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ __device__ constexpr index_t
+    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
         const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

@@ -387,7 +389,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
     template <typename TopIdx>
     __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
     {
-        return underlying_map_.CalculateBottomIndex(idx_top);
+        static_assert(TopIdx::Size() == 1);
+
+        return underlying_map_.CalculateBottomIndex(
+            make_multi_index(idx_top[I0] % CalculateGridSize()));
     }

     template <typename CTileIdx, typename CTileDim>

@@ -418,6 +423,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
     }

     private:
+    __device__ constexpr index_t CalculateGridSize() const
+    {
+        return CalculateGridSize(c_grid_desc_m_n_);
+    }
+
     __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
                                                       index_t M01,
                                                       index_t N01,

@@ -450,6 +460,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
         return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
     }

+    CGridDesc_M_N c_grid_desc_m_n_;
     index_t M01_, N01_, KSplit_;

     using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
     UnderlyingMap underlying_map_;
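Note: the map now stores the C grid descriptor so that, on device, it can fold a flat block id back into the per-group tile range. When the launcher multiplies the grid by Conv_G_, every group's blocks wrap onto the same tile coordinates. A standalone sketch of the wrap (simplified stand-in, not CK's type):

#include <cstdint>

// Simplified stand-in for the KSplit block-to-ctile map after this change.
struct TileMapSketch
{
    std::int64_t grid_size; // CalculateGridSize(c_grid_desc_m_n)

    std::int64_t bottom_index(std::int64_t block_id) const
    {
        return block_id % grid_size; // idx_top[I0] % CalculateGridSize()
    }
};

int main()
{
    const TileMapSketch map{/*grid_size=*/120};
    // With Conv_G_ = 4 groups the launch uses 480 blocks; block 250 maps to tile 10.
    return map.bottom_index(250) == 10 ? 0 : 1;
}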
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp

@@ -289,7 +289,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
                                                            XDataType,
                                                            decltype(thread_buffer_desc_m_k),
                                                            GridDesc_M_K,
-                                                           YElementwiseOperation,
+                                                           PassThrough,
                                                            ThreadBufferLengths_M_K,
                                                            ThreadBufferDimAccessOrder,
                                                            XSrcVectorDim,
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp

@@ -4,6 +4,7 @@
 #pragma once

+#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"

@@ -47,10 +48,9 @@ struct TransformConvFwdToGemm
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                    index_t{1},
-                                                    std::multiplies<index_t>());
+            const index_t NWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

             const auto in_gemmm_gemmk_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(NWo, C));

@@ -146,10 +146,9 @@ struct TransformConvFwdToGemm
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

             const auto in_gemmm_gemmk_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));

@@ -262,10 +261,8 @@ struct TransformConvFwdToGemm
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
             const index_t NDoHoWo =
-                N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                    index_t{1},
-                                    std::multiplies<index_t>());
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

             const auto in_gemmm_gemmk_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));

@@ -390,10 +387,9 @@ struct TransformConvFwdToGemm
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

             // This is different
             const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];

@@ -506,10 +502,9 @@ struct TransformConvFwdToGemm
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

             // This is different
             const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];

@@ -639,10 +634,8 @@ struct TransformConvFwdToGemm
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
             const index_t NDoHoWo =
-                N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                    index_t{1},
-                                    std::multiplies<index_t>());
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

             // This is different
             const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];

@@ -768,10 +761,8 @@ struct TransformConvFwdToGemm
         const index_t K = b_g_k_c_xs_lengths[1];
         const index_t C = b_g_k_c_xs_lengths[2];

-        const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
-                                           b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
-                                           index_t{1},
-                                           std::multiplies<index_t>());
+        const index_t YX = ck::accumulate_n<index_t>(
+            b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

         const auto wei_gemmn_gemmk_desc =
             make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));

@@ -794,10 +785,8 @@ struct TransformConvFwdToGemm
         const index_t K = b_g_k_c_xs_lengths[1];
         const index_t C = b_g_k_c_xs_lengths[2];

-        const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
-                                           b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
-                                           index_t{1},
-                                           std::multiplies<index_t>());
+        const index_t YX = ck::accumulate_n<index_t>(
+            b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

         const index_t KStride = b_g_k_c_xs_strides[1];
         const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];

@@ -827,10 +816,9 @@ struct TransformConvFwdToGemm
         const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];

-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

         const auto out_gemmm_gemmn_desc =
             make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));

@@ -855,10 +843,9 @@ struct TransformConvFwdToGemm
         const auto KStride     = I1;
         const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];

-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

         const auto out_gemmm_gemmn_desc =
             make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));

@@ -878,10 +865,9 @@ struct TransformConvFwdToGemm
         const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];

-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

         const auto out_gemmm_gemmn_desc =
             make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
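Note: in the Filter1x1Stride1Pad0 branches above, every output pixel reads exactly one input pixel, so the convolution collapses to a plain GEMM whose M dimension is N times the product of the spatial extents. A shape-arithmetic sketch (sizes are illustrative):

#include <functional>
#include <numeric>
#include <vector>

int main()
{
    const int N = 8, C = 64, K = 128;
    const std::vector<int> out_spatial{56, 56}; // equals input spatial for 1x1/stride-1/pad-0

    const int NHoWo = N * std::accumulate(out_spatial.begin(), out_spatial.end(), 1,
                                          std::multiplies<>{});
    // A: (NHoWo, C), B: (K, C), C: (NHoWo, K) -- an ordinary GEMM.
    return NHoWo == 25088 ? 0 : 1;
}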
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp

@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
         else if constexpr(NDimSpatial == 2)
         {
             auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
+                std::size_t N  = arg.output_.GetLengths()[1];
+                std::size_t Ho = arg.output_.GetLengths()[3];
+                std::size_t Wo = arg.output_.GetLengths()[4];
+
                 float v_acc = 0;
-                for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
+                for(std::size_t n = 0; n < N; ++n)
                 {
-                    for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
+                    for(std::size_t ho = 0; ho < Ho; ++ho)
                     {
                         auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                                   static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                                   static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                        for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
+                        for(std::size_t wo = 0; wo < Wo; ++wo)
                         {
                             auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
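Note: the f_kcyx lambda accumulates the standard weight-gradient sum. With the index order implied by the GetLengths calls (output dims [G, N, K, Ho, Wo]) and the visible hi/wi arithmetic, the quantity per weight element is (the exact in/out pairing inside the innermost loop is an assumption, since those lines are truncated in this view):

\[
\mathrm{wei}(g,k,c,y,x) \;=\; \sum_{n}\sum_{h_o}\sum_{w_o}
\mathrm{out}(g,n,k,h_o,w_o)\,\cdot\,\mathrm{in}(g,n,c,h_i,w_i),
\]
\[
h_i = h_o\,s_0 + y\,d_0 - p_0, \qquad w_i = w_o\,s_1 + x\,d_1 - p_1,
\]

where out-of-range \((h_i, w_i)\) contribute nothing; hoisting N, Ho, Wo out of the loops avoids re-reading the lengths on every iteration.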
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp

@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
             size_t M = acc.mDesc.GetLengths()[0];
             size_t N = acc.mDesc.GetLengths()[1];

-            Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
-            Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
+            Tensor<ComputeDataType> avg_acc_sq({M});
+            Tensor<ComputeDataType> avg_acc({M});
             Tensor<ComputeDataType> acc_layernorm(acc);

             // reduce N dim
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp

@@ -92,9 +92,10 @@ struct ReferenceLayernorm : public device::BaseOperator
             {
                 for(int n = 0; n < N; ++n)
                 {
-                    auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
-                    auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
-                    y_val      = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
+                    auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
+                    auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
+                    y_val      = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
+
                     arg.acc_elementwise_op_(y_val, y_val);
                     arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
                 }
             }
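Note: the loop body implements the usual layernorm transform, per row m and column n:

\[
y_{m,n} \;=\; \gamma_n \cdot \frac{x_{m,n} - \mu_m}{\sqrt{\sigma_m^2 + \varepsilon}} + \beta_n,
\]

followed by arg.acc_elementwise_op_ applied to \(y_{m,n}\) before the final conversion to YDataType.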
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

@@ -95,7 +95,7 @@ template <typename Activation>
 using Add_Activation_Mul_Clamp =
     ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;

-template <typename DeviceOp>
+template <typename DeviceOp, typename Tag = void>
 struct DeviceOperationInstanceFactory;

 } // namespace instance
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp
deleted (100644 → 0), last at ab663329

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, BF16, F32, BF16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, F16, F16, F16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, F32, F32, F32,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, BF16, F32, BF16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, F16, F16, F16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, F32, F32, F32,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, BF16, F32, BF16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, F16, F16, F16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, F32, F32, F32,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceConvBwdWeight<
        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceConvBwdWeight<
        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> &&
                     is_same_v<WeiLayout, KXC> && is_same_v<OutLayout, NWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                          is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
                          is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp

@@ -5,7 +5,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

@@ -17,46 +17,54 @@ namespace instance {
 // conv2d backward data
 void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvBwdData<2, GNHWC, GKYXC, GNHWK, F16, F16, F16,
-                                                         PassThrough, PassThrough, PassThrough>>>&
-        instances);
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple,
+                                                                  GNHWC, F16, F16, Empty_Tuple,
+                                                                  F16, PassThrough, PassThrough,
+                                                                  PassThrough>>>&
+        instances);

-template <ck::index_t NumDimSpatial,
-          typename InLayout,
-          typename WeiLayout,
-          typename OutLayout,
-          typename InDataType,
-          typename WeiDataType,
-          typename OutDataType>
-struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceGroupedConvBwdData<
-        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
-        ck::tensor_operation::element_wise::PassThrough,
-        ck::tensor_operation::element_wise::PassThrough,
-        ck::tensor_operation::element_wise::PassThrough>>
+template <ck::index_t NumDimSpatial,
+          typename OutLayout,
+          typename WeiLayout,
+          typename InLayout,
+          typename OutDataType,
+          typename WeiDataType,
+          typename InDataType>
+struct DeviceOperationInstanceFactory<
+    ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
+        NumDimSpatial, OutLayout, WeiLayout, Empty_Tuple, InLayout,
+        OutDataType, WeiDataType, Empty_Tuple, InDataType,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough>>
 {
-    using DeviceOp = DeviceGroupedConvBwdData<
-        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
-        ck::tensor_operation::element_wise::PassThrough,
-        ck::tensor_operation::element_wise::PassThrough,
-        ck::tensor_operation::element_wise::PassThrough>;
+    using DeviceOp = DeviceGroupedConvBwdDataMultipleD<
+        NumDimSpatial, OutLayout, WeiLayout, Empty_Tuple, InLayout,
+        OutDataType, WeiDataType, Empty_Tuple, InDataType,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough>;

     static auto GetInstances()
     {
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
new file (0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv1d backward weight
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK, BF16, F32, BF16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK, F16, F16, F16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK, F32, F32, F32,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

// conv2d backward weight
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK, BF16, F32,
                                                           BF16, PassThrough, PassThrough,
                                                           PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK, F16, F16, F16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK, F32, F32, F32,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

// conv3d backward weight
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK, BF16, F32,
                                                           BF16, PassThrough, PassThrough,
                                                           PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK, F16, F16,
                                                           F16, PassThrough, PassThrough,
                                                           PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK, F32, F32,
                                                           F32, PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedConvBwdWeight<
        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
                     is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                          is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
                          is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
                    op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
                    op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
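Note: the new header follows CK's instance-factory pattern, where a primary template is left undefined and each specialization's GetInstances() collects the matching pre-built device ops. A minimal self-contained sketch of that pattern with mock types (none of these names are CK's):

#include <memory>
#include <vector>

struct BaseOp { virtual ~BaseOp() = default; };
struct XdlGroupedConvBwdWeightF32 : BaseOp {};

// Primary template intentionally has no definition; unsupported combinations
// fail at compile time.
template <typename DeviceOp> struct InstanceFactory;

template <> struct InstanceFactory<XdlGroupedConvBwdWeightF32>
{
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<BaseOp>> op_ptrs;
        op_ptrs.push_back(std::make_unique<XdlGroupedConvBwdWeightF32>());
        return op_ptrs;
    }
};

int main()
{
    auto ops = InstanceFactory<XdlGroupedConvBwdWeightF32>::GetInstances();
    return ops.empty() ? 1 : 0;
}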
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp

@@ -3,11 +3,11 @@
 #pragma once

-#include <cstdlib>
+#include <vector>

 #include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp

@@ -3,11 +3,9 @@
 #pragma once

 #include <cstdlib>

 #include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
library/include/ck/library/utility/algorithm.hpp
new file (0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <iterator>
#include <type_traits>
#include <utility>

namespace ck {
namespace ranges {

template <typename InputRange, typename OutputIterator>
auto copy(InputRange&& range, OutputIterator iter)
    -> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
                          std::end(std::forward<InputRange>(range)),
                          iter))
{
    return std::copy(std::begin(std::forward<InputRange>(range)),
                     std::end(std::forward<InputRange>(range)),
                     iter);
}

template <typename T, typename OutputRange>
auto fill(OutputRange&& range, const T& init)
    -> std::void_t<decltype(std::fill(std::begin(std::forward<OutputRange>(range)),
                                      std::end(std::forward<OutputRange>(range)),
                                      init))>
{
    std::fill(std::begin(std::forward<OutputRange>(range)),
              std::end(std::forward<OutputRange>(range)),
              init);
}

template <typename InputRange, typename OutputIterator, typename UnaryOperation>
auto transform(InputRange&& range, OutputIterator iter, UnaryOperation unary_op)
    -> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op))
{
    return std::transform(std::begin(range), std::end(range), iter, unary_op);
}

} // namespace ranges
} // namespace ck
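Note: these wrappers behave exactly like their std:: counterparts but take a whole range instead of an iterator pair, and SFINAE away when the range is not iterable. A short usage sketch (assumes the header above is on the include path):

#include <array>
#include <iostream>
#include <vector>
#include "ck/library/utility/algorithm.hpp"

int main()
{
    std::array<int, 4> src{1, 2, 3, 4};
    std::vector<int> dst(src.size());

    ck::ranges::fill(dst, 0);                                             // dst = {0, 0, 0, 0}
    ck::ranges::copy(src, dst.begin());                                   // dst = {1, 2, 3, 4}
    ck::ranges::transform(src, dst.begin(), [](int x) { return x * x; }); // dst = {1, 4, 9, 16}

    for(int v : dst)
        std::cout << v << ' ';
}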
library/include/ck/library/utility/check_err.hpp

@@ -15,18 +15,22 @@
 #include "ck/ck.hpp"
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/span.hpp"
+#include "ck/utility/type.hpp"
 #include "ck/host_utility/io.hpp"
+#include "ck/library/utility/ranges.hpp"

 namespace ck {
 namespace utils {

-template <typename T>
-typename std::enable_if<std::is_floating_point<T>::value && !std::is_same<T, half_t>::value,
-                        bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_floating_point_v<ranges::range_value_t<Range>> &&
+        !std::is_same_v<ranges::range_value_t<Range>, half_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol = 1e-5,
           double atol = 3e-6)

@@ -44,15 +48,17 @@ check_err(const std::vector<T>& out,
     double max_err = std::numeric_limits<double>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        err = std::abs(out[i] - ref[i]);
-        if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) ||
-           !std::isfinite(ref[i]))
+        const double o = *std::next(std::begin(out), i);
+        const double r = *std::next(std::begin(ref), i);
+        err            = std::abs(o - r);
+        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
             if(err_count < 5)
             {
                 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
-                          << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;
         }
@@ -64,10 +70,13 @@ check_err(const std::vector<T>& out,
    return res;
}

-template <typename T>
-typename std::enable_if<std::is_same<T, bhalf_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol            = 1e-3,
           double atol            = 1e-3)
...
...
@@ -86,9 +95,9 @@ check_err(const std::vector<T>& out,
    double max_err = std::numeric_limits<float>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
-        double o = type_convert<float>(out[i]);
-        double r = type_convert<float>(ref[i]);
-        err      = std::abs(o - r);
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
+        err            = std::abs(o - r);
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
...
...
@@ -108,10 +117,13 @@ check_err(const std::vector<T>& out,
    return res;
}

-template <typename T>
-typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
-check_err(span<const T> out,
-          span<const T> ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, half_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol            = 1e-3,
           double atol            = 1e-3)
...
...
@@ -126,12 +138,12 @@ check_err(span<const T> out,
    bool res{true};
    int err_count = 0;
    double err    = 0;
-    double max_err = std::numeric_limits<T>::min();
+    double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
-        double o = type_convert<float>(out[i]);
-        double r = type_convert<float>(ref[i]);
-        err      = std::abs(o - r);
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
+        err            = std::abs(o - r);
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
...
...
@@ -151,26 +163,17 @@ check_err(span<const T> out,
    return res;
}

-template <typename T>
-typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
-          const std::string& msg = "Error: Incorrect results!",
-          double rtol            = 1e-3,
-          double atol            = 1e-3)
-{
-    return check_err(span<const T>{out}, span<const T>{ref}, msg, rtol, atol);
-}
-
-template <typename T>
-std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
+template <typename Range, typename RefRange>
+std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+                  std::is_integral_v<ranges::range_value_t<Range>> &&
+                  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-                     || std::is_same_v<T, int4_t>
+                     || std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
                     ,
                 bool>
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double                 = 0,
           double atol            = 0)
...
...
@@ -188,9 +191,9 @@ check_err(const std::vector<T>& out,
    int64_t max_err = std::numeric_limits<int64_t>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
-        int64_t o = out[i];
-        int64_t r = ref[i];
-        err       = std::abs(o - r);
+        const int64_t o = *std::next(std::begin(out), i);
+        const int64_t r = *std::next(std::begin(ref), i);
+        err             = std::abs(o - r);
        if(err > atol)
        {
...
...
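Taken together, these hunks replace the type-tagged std::vector/span overloads with one family of range-based overloads keyed on ranges::range_value_t, so any container with begin/end (std::vector, ck::span, Tensor storage) is validated without copies. A usage sketch with hypothetical values, assuming the header as changed above:

#include <vector>
#include "ck/library/utility/check_err.hpp"

bool validate()
{
    std::vector<float> out{1.0f, 2.0f, 3.0f};
    std::vector<float> ref{1.0f, 2.0f, 3.0000001f};

    // Resolves to the floating-point overload; passes when
    // |out[i] - ref[i]| <= atol + rtol * |ref[i]| for every i.
    return ck::utils::check_err(out, ref, "Error: Incorrect results!", 1e-5, 3e-6);
}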
library/include/ck/library/utility/convolution_parameter.hpp
...
...
@@ -10,6 +10,8 @@
#include "ck/ck.hpp"
#include "ck/library/utility/numeric.hpp"
namespace ck {
namespace utils {
namespace conv {
...
...
@@ -55,10 +57,8 @@ struct ConvParam
        // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
        return sizeof(InDataType) *
               (G_ * N_ * C_ *
-                std::accumulate(std::begin(input_spatial_lengths_),
-                                std::begin(input_spatial_lengths_) + num_dim_spatial_,
-                                static_cast<std::size_t>(1),
-                                std::multiplies<std::size_t>()));
+                ck::accumulate_n<std::size_t>(
+                    std::begin(input_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
    }

    template <typename WeiDataType>
...
...
@@ -67,10 +67,8 @@ struct ConvParam
        // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
        return sizeof(WeiDataType) *
               (G_ * K_ * C_ *
-                std::accumulate(std::begin(filter_spatial_lengths_),
-                                std::begin(filter_spatial_lengths_) + num_dim_spatial_,
-                                static_cast<std::size_t>(1),
-                                std::multiplies<std::size_t>()));
+                ck::accumulate_n<std::size_t>(
+                    std::begin(filter_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
    }

    template <typename OutDataType>
...
...
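The switch from std::accumulate over an explicit iterator pair to ck::accumulate_n (from "ck/library/utility/numeric.hpp", included above) makes the "fold the first n elements" intent explicit. Its behavior as used here matches this standalone sketch; the body below is an illustrative reimplementation, not the library source:

#include <functional>
#include <iterator>
#include <numeric>
#include <vector>

// Illustrative stand-in for ck::accumulate_n as used above:
// fold the first `n` elements of a range with `op`, starting from `init`.
template <typename T, typename ForwardIterator, typename Size, typename BinaryOperation>
T accumulate_n(ForwardIterator first, Size n, T init, BinaryOperation op)
{
    return std::accumulate(first, std::next(first, n), init, op);
}

int main()
{
    std::vector<int> spatial{3, 3, 1, 1}; // e.g. filter lengths padded to a fixed rank
    // Product of the first num_dim_spatial_ = 2 entries -> 9.
    auto product = accumulate_n<std::size_t>(spatial.begin(), 2, 1, std::multiplies<>());
    return product == 9 ? 0 : 1;
}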
library/include/ck/library/utility/fill.hpp
...
...
@@ -30,9 +30,10 @@ struct FillUniformDistribution
}
    template <typename ForwardRange>
-    auto operator()(ForwardRange&& range)
-        -> std::void_t<decltype(std::declval<FillUniformDistribution>()(
-            std::begin(std::forward<ForwardRange>(range)),
-            std::end(std::forward<ForwardRange>(range))))>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
...
...
@@ -72,6 +73,16 @@ struct FillUniformDistributionIntegerValue
        std::generate(first, last, [&dis, &gen]() {
            return ck::type_convert<T>(std::round(dis(gen)));
        });
    }

+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t<decltype(std::declval<const FillUniformDistributionIntegerValue&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
};

template <typename T>
...
...
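With the const-qualified range overloads, a filler can now be applied through a temporary or const reference. A usage sketch; the constructor arguments (the value range) are an assumption, since this diff only shows the call operators:

#include <vector>
#include "ck/library/utility/fill.hpp"

void init_input(std::vector<float>& data)
{
    // Hypothetical constructor arguments; not shown in this diff.
    ck::utils::FillUniformDistribution<float> filler{-1.f, 1.f};
    filler(data); // const range overload added in this commit
}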
library/include/ck/library/utility/host_tensor.hpp
...
...
@@ -14,6 +14,9 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/span.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/ranges.hpp"
template <typename Range>
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
{
...
...
@@ -84,10 +87,10 @@ struct HostTensorDescriptor
        this->CalculateStrides();
    }

-    template <typename Range,
-              typename = std::enable_if_t<std::is_convertible_v<
-                  decltype(*std::begin(std::declval<Range>())), std::size_t>>>
-    HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
+    template <typename Lengths,
+              typename = std::enable_if_t<
+                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t>>>
+    HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
    {
        this->CalculateStrides();
    }
...
...
@@ -102,13 +105,12 @@ struct HostTensorDescriptor
    {
    }

-    template <typename Range1,
-              typename Range2,
-              typename = std::enable_if_t<
-                  std::is_convertible_v<decltype(*std::begin(std::declval<Range1>())), std::size_t> &&
-                  std::is_convertible_v<decltype(*std::begin(std::declval<Range2>())), std::size_t>>>
-    HostTensorDescriptor(const Range1& lens, const Range2& strides)
+    template <typename Lengths,
+              typename Strides,
+              typename = std::enable_if_t<
+                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
+                  std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>>>
+    HostTensorDescriptor(const Lengths& lens, const Strides& strides)
        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
    {
    }
...
...
@@ -244,14 +246,20 @@ struct Tensor
    {
    }

-    template <typename X>
-    Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    template <typename X, typename Y>
+    Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
+        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
    {
    }

-    template <typename X, typename Y>
-    Tensor(std::vector<X> lens, std::vector<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
+    template <typename Lengths>
+    Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
    {
    }

+    template <typename Lengths, typename Strides>
+    Tensor(const Lengths& lens, const Strides& strides)
+        : mDesc(lens, strides), mData(GetElementSpaceSize())
+    {
+    }
...
...
@@ -261,10 +269,10 @@ struct Tensor
    Tensor<OutT> CopyAsType() const
    {
        Tensor<OutT> ret(mDesc);
-        for(size_t i = 0; i < mData.size(); i++)
-        {
-            ret.mData[i] = ck::type_convert<OutT>(mData[i]);
-        }
+        ck::ranges::transform(
+            mData, ret.mData.begin(), [](auto value) { return ck::type_convert<OutT>(value); });
        return ret;
    }
...
...
@@ -294,13 +302,7 @@ struct Tensor
    std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }

-    void SetZero()
-    {
-        for(auto& v : mData)
-        {
-            v = T{0};
-        }
-    }
+    void SetZero() { ck::ranges::fill<T>(mData, 0); }
    template <typename F>
    void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
...
...
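Net effect of the host_tensor.hpp changes: Tensor and HostTensorDescriptor now accept any range of lengths/strides (including brace-initializer lists), and the element-wise loops are routed through the new ck::ranges helpers. A usage sketch; the shapes and strides are made up for illustration:

#include "ck/library/utility/host_tensor.hpp"

void example()
{
    // Brace-initialized lengths/strides go through the new
    // initializer_list / generic-range constructors.
    Tensor<float> a({2, 3, 4}, {12, 4, 1});
    a.SetZero(); // now ck::ranges::fill under the hood

    // Element-wise conversion via ck::ranges::transform.
    Tensor<double> b = a.CopyAsType<double>();
}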
library/include/ck/library/utility/iterator.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iterator>
#include <utility>

#include "ck/utility/type.hpp"

namespace ck {

template <typename T>
using iter_value_t = typename std::iterator_traits<remove_cvref_t<T>>::value_type;

template <typename T>
using iter_reference_t = decltype(*std::declval<T&>());

template <typename T>
using iter_difference_t = typename std::iterator_traits<remove_cvref_t<T>>::difference_type;

} // namespace ck
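These aliases mirror C++20's std::iter_value_t / std::iter_reference_t / std::iter_difference_t on pre-C++20 toolchains and back the ranges::range_value_t trait used by check_err.hpp above. A quick compile-time check, assuming the header above:

#include <cstddef>
#include <type_traits>
#include <vector>
#include "ck/library/utility/iterator.hpp"

// iter_value_t strips cv/ref qualifiers before consulting iterator_traits,
// so these all resolve at compile time.
static_assert(std::is_same_v<ck::iter_value_t<std::vector<int>::iterator>, int>);
static_assert(std::is_same_v<ck::iter_reference_t<std::vector<int>::iterator>, int&>);
static_assert(std::is_same_v<ck::iter_difference_t<int*>, std::ptrdiff_t>);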