gaoqiong / composable_kernel · Commits · f0224f2a

Commit f0224f2a authored Nov 29, 2022 by letaoqin

    Merge branch 'develop' into dl_conv_multiple_d

Parents: befc2638, 0e9c88ce
Changes: 271

Showing 20 changed files with 1105 additions and 359 deletions (+1105 −359)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp (+2 −2)
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp (+4 −3)
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp (+7 −5)
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp (+3 −1)
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp (+56 −0)
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp (+117 −0)
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp (+39 −0)
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp (+0 −230)
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp (+145 −0)
library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp (+138 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp (+44 −36)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp (+235 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp (+2 −2)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp (+116 −0)
library/include/ck/library/utility/algorithm.hpp (+43 −0)
library/include/ck/library/utility/check_err.hpp (+46 −43)
library/include/ck/library/utility/convolution_parameter.hpp (+6 −8)
library/include/ck/library/utility/fill.hpp (+14 −3)
library/include/ck/library/utility/host_common_util.hpp (+60 −0)
library/include/ck/library/utility/host_tensor.hpp (+28 −26)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp

@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
         size_t M = acc.mDesc.GetLengths()[0];
         size_t N = acc.mDesc.GetLengths()[1];
-        Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
-        Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
+        Tensor<ComputeDataType> avg_acc_sq({M});
+        Tensor<ComputeDataType> avg_acc({M});
         Tensor<ComputeDataType> acc_layernorm(acc);

         // reduce N dim
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp

@@ -92,9 +92,10 @@ struct ReferenceLayernorm : public device::BaseOperator
             {
                 for(int n = 0; n < N; ++n)
                 {
                     auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
                     auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
                     y_val      = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
+                    arg.acc_elementwise_op_(y_val, y_val);
                     arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
                 }
             }
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp

@@ -86,8 +86,8 @@ struct ReferenceSoftmax : public device::BaseOperator
         };

         arg.in_.ForEach([&](auto& self, auto idx) {
-            reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)),
-                                                         static_cast<AccDataType>(self(idx)));
+            reduce_max(to_sm_scalar_idx(idx)) = std::max(
+                reduce_max(to_sm_scalar_idx(idx)), ck::type_convert<AccDataType>(self(idx)));
         });

         // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<

@@ -96,7 +96,7 @@ struct ReferenceSoftmax : public device::BaseOperator
         Tensor<AccDataType> in_stable(arg.in_.mDesc);

         in_stable.ForEach([&](auto& self, auto idx) {
             // numerator = exp(x - max(x))
-            self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
-                                 reduce_max(to_sm_scalar_idx(idx)));
+            self(idx) = std::exp(ck::type_convert<AccDataType>(arg.in_(idx)) -
+                                 reduce_max(to_sm_scalar_idx(idx)));
         });

@@ -111,8 +111,10 @@ struct ReferenceSoftmax : public device::BaseOperator
         // std::endl;

         arg.out_.ForEach([&](auto& self, auto idx) {
-            self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
-                        arg.beta_ * self(idx);
+            AccDataType temp_result =
+                arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
+                arg.beta_ * self(idx);
+            self(idx) = ck::type_convert<OutDataType>(temp_result);
         });

         // LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
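The softmax change above keeps the alpha/beta accumulation in AccDataType and converts to OutDataType only at the final store. A minimal scalar sketch of the same per-row computation, assuming float for both the accumulator and the output type (the function below is illustrative, not part of the CK API):

// Scalar sketch of the reference: out = alpha * softmax(x) + beta * out_prev.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> softmax_alpha_beta(const std::vector<float>& x,
                                      const std::vector<float>& y_prev,
                                      float alpha,
                                      float beta)
{
    // numerically stable softmax: subtract the row max before exponentiating
    const float x_max = *std::max_element(x.begin(), x.end());

    std::vector<float> e(x.size());
    float sum = 0.f;
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        e[i] = std::exp(x[i] - x_max);
        sum += e[i];
    }

    std::vector<float> y(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        // accumulate first, convert to the output type only at the end
        const float temp_result = alpha * e[i] / sum + beta * y_prev[i];
        y[i]                    = temp_result;
    }
    return y;
}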
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

@@ -87,6 +87,8 @@ using Relu = ck::tensor_operation::element_wise::Relu;
 using Scale          = ck::tensor_operation::element_wise::Scale;
 using Bilinear       = ck::tensor_operation::element_wise::Bilinear;
 using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+using AddFastGelu    = ck::tensor_operation::element_wise::AddFastGelu;
+using FastGelu       = ck::tensor_operation::element_wise::FastGelu;

 template <typename Activation>
 using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;

@@ -95,7 +97,7 @@ template <typename Activation>
 using Add_Activation_Mul_Clamp =
     ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;

-template <typename DeviceOp>
+template <typename DeviceOp, typename Tag = void>
 struct DeviceOperationInstanceFactory;

 } // namespace instance
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp

@@ -59,6 +59,48 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
                                                 MaskingSpecialization::MaskDisabled>>>& instances);

+void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1,
+                                            BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>,
+                                            PassThrough, PassThrough, Scale, PassThrough, PassThrough,
+                                            MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
+
+void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1,
+                                            BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>,
+                                            PassThrough, PassThrough, Scale, PassThrough, PassThrough,
+                                            MaskingSpecialization::MaskDisabled>>>& instances);
+
 template <typename ADataType,
           typename B0DataType,
           typename B1DataType,

@@ -119,6 +161,20 @@ struct DeviceOperationInstanceFactory<
                     op_ptrs);
             }
         }
+        else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
+                          is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
+        {
+            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
+            {
+                add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+                    op_ptrs);
+            }
+            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
+            {
+                add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+                    op_ptrs);
+            }
+        }

         return op_ptrs;
     }
 };
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
    std::vector<std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);

// FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
    std::vector<std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);

// BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
    std::vector<std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);

// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
    std::vector<std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);

template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatchNormFwd<
    XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType,
    YElementwiseOp, Rank, NumReduceDim>>
{
    using DeviceOp = DeviceBatchNormFwd<XDataType, YDataType, AccDataType, ScaleDataType,
                                        BiasDataType, MeanVarDataType, YElementwiseOp, Rank,
                                        NumReduceDim>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
                     is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
                     is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
                          is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
                          is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
                          is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
                          is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
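The new header follows the usual instance-factory pattern in this commit: client code names a DeviceBatchNormFwd configuration and asks the factory for every registered instance. A minimal sketch of such a query for an FP16 rank-4 problem; only the factory specialization and GetInstances() come from this header, while the enumeration loop and GetTypeString() call are illustrative:

// Sketch: enumerating the FP16 rank-4/NumReduceDim-3 batchnorm-forward instances.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"

void list_batchnorm_fwd_f16_instances()
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwd<ck::half_t,  // XDataType
                                                                      ck::half_t,  // YDataType
                                                                      float,       // AccDataType
                                                                      ck::half_t,  // ScaleDataType
                                                                      ck::half_t,  // BiasDataType
                                                                      float,       // MeanVarDataType
                                                                      PassThrough, // YElementwiseOp
                                                                      4,           // Rank
                                                                      3>;          // NumReduceDim

    // GetInstances() returns every instance registered for this configuration.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << std::endl;
}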
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp

@@ -101,6 +101,42 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
                                                   PassThrough,
                                                   PassThrough,
                                                   PassThrough>>>& instances);

+// conv2d dl
+void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
+    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough,
+                                                  PassThrough, PassThrough>>>& instances);
+
+void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
+    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough,
+                                                  PassThrough, PassThrough>>>& instances);
+
+void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
+    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, int8_t, int8_t, int8_t,
+                                                  PassThrough, PassThrough, PassThrough>>>& instances);
+
 // conv3d backward data
 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<3,

@@ -216,11 +252,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
                      is_same_v<OutDataType, float>)
         {
             add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
+            add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
         }
         else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                           is_same_v<OutDataType, half_t>)
         {
             add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
+            add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
         }
         else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                           is_same_v<WeiDataType, ck::bhalf_t> &&

@@ -232,6 +270,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
                           is_same_v<OutDataType, int8_t>)
         {
             add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
+            add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
         }
     }
     else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp (deleted, 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, BF16, F32, BF16, PassThrough,
                                                    PassThrough, PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, F16, F16, F16, PassThrough,
                                                    PassThrough, PassThrough>>>& instances);
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, F32, F32, F32, PassThrough,
                                                    PassThrough, PassThrough>>>& instances);

// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, BF16, F32, BF16,
                                                    PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough,
                                                    PassThrough, PassThrough>>>& instances);
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough,
                                                    PassThrough, PassThrough>>>& instances);

// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, BF16, F32, BF16,
                                                    PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, F16, F16, F16,
                                                    PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, F32, F32, F32,
                                                    PassThrough, PassThrough, PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBwdWeight<
    NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceConvBwdWeight<NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType,
                                         WeiDataType, OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
                     is_same_v<OutLayout, NWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                          is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
                          is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple,
                                                    F16, PassThrough, PassThrough, AddFastGelu>>>&);

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple,
                                                    F16, PassThrough, PassThrough, AddFastGelu>>>&);

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple,
                                                    F16, PassThrough, PassThrough, AddFastGelu>>>&);

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple,
                                                    F16, PassThrough, PassThrough, AddFastGelu>>>&);

// GEMM + Add + FastGelu
template <typename ALayout,
          typename BLayout,
          typename D0Layout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename D0DataType,
          typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
    ALayout, BLayout, ck::Tuple<D0Layout>, ELayout, ADataType, BDataType, ck::Tuple<D0DataType>,
    EDataType, PassThrough, PassThrough, AddFastGelu>>
{
    using DeviceOp = DeviceGemmMultipleD<ALayout, BLayout, ck::Tuple<D0Layout>, ELayout, ADataType,
                                         BDataType, ck::Tuple<D0DataType>, EDataType, PassThrough,
                                         PassThrough, AddFastGelu>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
                    op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
                    op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
                    op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
                    op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
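The factory specialization above is keyed on the full DeviceGemmMultipleD signature, including the ck::Tuple wrappers around the single D0 operand. A minimal sketch of querying the FP16 row-major GEMM + Add + FastGelu instances; only the specialization and GetInstances() come from the header, and the enumeration loop is illustrative:

// Sketch: E = AddFastGelu(A * B, D0) with all operands FP16 and row-major.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp"

void list_gemm_add_fastgelu_instances()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;

    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
        Row, Row, ck::Tuple<Row>, Row,                                  // layouts: A, B, {D0}, E
        ck::half_t, ck::half_t, ck::Tuple<ck::half_t>, ck::half_t,      // data types
        PassThrough, PassThrough, AddFastGelu>;                         // element-wise ops

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << "\n";
}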
library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Empty_Tuple, Row, F16, F16,
                                                    Empty_Tuple, F16, PassThrough, PassThrough,
                                                    FastGelu>>>&);

void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Empty_Tuple, Row, F16, F16,
                                                    Empty_Tuple, F16, PassThrough, PassThrough,
                                                    FastGelu>>>&);

void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Row, Empty_Tuple, Row, F16, F16,
                                                    Empty_Tuple, F16, PassThrough, PassThrough,
                                                    FastGelu>>>&);

void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Col, Empty_Tuple, Row, F16, F16,
                                                    Empty_Tuple, F16, PassThrough, PassThrough,
                                                    FastGelu>>>&);

// GEMM + FastGelu
template <typename ALayout,
          typename BLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
    ALayout, BLayout, Empty_Tuple, ELayout, ADataType, BDataType, Empty_Tuple, EDataType,
    PassThrough, PassThrough, FastGelu>>
{
    using DeviceOp = DeviceGemmMultipleD<ALayout, BLayout, Empty_Tuple, ELayout, ADataType,
                                         BDataType, Empty_Tuple, EDataType, PassThrough,
                                         PassThrough, FastGelu>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<ELayout, Row>)
            {
                add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<ELayout, Row>)
            {
                add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<ELayout, Row>)
            {
                add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<ELayout, Row>)
            {
                add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp

@@ -5,7 +5,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

@@ -17,46 +17,54 @@ namespace instance {
 // conv2d backward data
 void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
-    std::vector<std::unique_ptr<DeviceGroupedConvBwdData<2, GNHWC, GKYXC, GNHWK, F16, F16, F16,
-                                                         PassThrough, PassThrough,
-                                                         PassThrough>>>& instances);
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple,
+                                                                  GNHWC, F16, F16, Empty_Tuple,
+                                                                  F16, PassThrough, PassThrough,
+                                                                  PassThrough>>>& instances);

-template <ck::index_t NumDimSpatial,
-          typename InLayout,
-          typename WeiLayout,
-          typename OutLayout,
-          typename InDataType,
-          typename WeiDataType,
-          typename OutDataType>
-struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceGroupedConvBwdData<
-        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
-        ck::tensor_operation::element_wise::PassThrough,
-        ck::tensor_operation::element_wise::PassThrough,
-        ck::tensor_operation::element_wise::PassThrough>>
+template <ck::index_t NumDimSpatial,
+          typename OutLayout,
+          typename WeiLayout,
+          typename InLayout,
+          typename OutDataType,
+          typename WeiDataType,
+          typename InDataType>
+struct DeviceOperationInstanceFactory<
+    ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
+        NumDimSpatial, OutLayout, WeiLayout, Empty_Tuple, InLayout, OutDataType, WeiDataType,
+        Empty_Tuple, InDataType,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough>>
 {
-    using DeviceOp = DeviceGroupedConvBwdData<NumDimSpatial, InLayout, WeiLayout, OutLayout,
-                                              InDataType, WeiDataType, OutDataType,
-                                              ck::tensor_operation::element_wise::PassThrough,
-                                              ck::tensor_operation::element_wise::PassThrough,
-                                              ck::tensor_operation::element_wise::PassThrough>;
+    using DeviceOp =
+        DeviceGroupedConvBwdDataMultipleD<NumDimSpatial, OutLayout, WeiLayout, Empty_Tuple,
+                                          InLayout, OutDataType, WeiDataType, Empty_Tuple,
+                                          InDataType,
+                                          ck::tensor_operation::element_wise::PassThrough,
+                                          ck::tensor_operation::element_wise::PassThrough,
+                                          ck::tensor_operation::element_wise::PassThrough>;

     static auto GetInstances()
     {
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv1d backward weight
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK, BF16, F32, BF16,
                                                           PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK, F16, F16, F16,
                                                           PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK, F32, F32, F32,
                                                           PassThrough, PassThrough, PassThrough>>>& instances);

// conv2d backward weight
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK, BF16, F32, BF16,
                                                           PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK, F16, F16, F16,
                                                           PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK, F32, F32, F32,
                                                           PassThrough, PassThrough, PassThrough>>>& instances);

// conv3d backward weight
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK, BF16, F32,
                                                           BF16, PassThrough, PassThrough,
                                                           PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK, F16, F16,
                                                           F16, PassThrough, PassThrough,
                                                           PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK, F32, F32,
                                                           F32, PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
    NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial, InLayout, WeiLayout, OutLayout,
                                                InDataType, WeiDataType, OutDataType,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
                     is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                          is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
                          is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
                    op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp

@@ -3,11 +3,11 @@
 #pragma once

-#include <cstdlib>
+#include <vector>

 #include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK, F16, F16, F16,
                                                     PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK, F32, F32, F32,
                                                     PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK, int8_t, int8_t, int8_t,
                                                     PassThrough, PassThrough, PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwd<
    NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedConvFwd<NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType,
                                          WeiDataType, OutDataType,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                              is_same_v<OutDataType, int8_t>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/utility/algorithm.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <iterator>
#include <type_traits>
#include <utility>

namespace ck {
namespace ranges {

template <typename InputRange, typename OutputIterator>
auto copy(InputRange&& range, OutputIterator iter)
    -> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
                          std::end(std::forward<InputRange>(range)),
                          iter))
{
    return std::copy(std::begin(std::forward<InputRange>(range)),
                     std::end(std::forward<InputRange>(range)),
                     iter);
}

template <typename T, typename OutputRange>
auto fill(OutputRange&& range, const T& init)
    -> std::void_t<decltype(std::fill(std::begin(std::forward<OutputRange>(range)),
                                      std::end(std::forward<OutputRange>(range)),
                                      init))>
{
    std::fill(std::begin(std::forward<OutputRange>(range)),
              std::end(std::forward<OutputRange>(range)),
              init);
}

template <typename InputRange, typename OutputIterator, typename UnaryOperation>
auto transform(InputRange&& range, OutputIterator iter, UnaryOperation unary_op)
    -> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op))
{
    return std::transform(std::begin(range), std::end(range), iter, unary_op);
}

} // namespace ranges
} // namespace ck
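These helpers only forward a whole range to the corresponding iterator-pair algorithm, so callers can pass a container directly instead of spelling out begin/end. A small usage sketch with illustrative data:

// Sketch: passing containers straight to the ck::ranges wrappers.
#include <vector>
#include "ck/library/utility/algorithm.hpp"

void ranges_demo()
{
    std::vector<int> src{1, 2, 3, 4};
    std::vector<int> dst(src.size());

    ck::ranges::copy(src, dst.begin());                                    // dst = {1, 2, 3, 4}
    ck::ranges::transform(src, dst.begin(), [](int x) { return 2 * x; });  // dst = {2, 4, 6, 8}
    ck::ranges::fill(dst, 0);                                              // dst = {0, 0, 0, 0}
}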
library/include/ck/library/utility/check_err.hpp

@@ -15,18 +15,22 @@
 #include "ck/ck.hpp"
 #include "ck/utility/data_type.hpp"
-#include "ck/utility/span.hpp"
 #include "ck/utility/type.hpp"
 #include "ck/host_utility/io.hpp"

+#include "ck/library/utility/ranges.hpp"
+
 namespace ck {
 namespace utils {

-template <typename T>
-typename std::enable_if<std::is_floating_point<T>::value && !std::is_same<T, half_t>::value,
-                        bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_floating_point_v<ranges::range_value_t<Range>> &&
+        !std::is_same_v<ranges::range_value_t<Range>, half_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol            = 1e-5,
           double atol            = 3e-6)

@@ -44,15 +48,17 @@ check_err(const std::vector<T>& out,
     double max_err = std::numeric_limits<double>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        err = std::abs(out[i] - ref[i]);
-        if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i]))
+        const double o = *std::next(std::begin(out), i);
+        const double r = *std::next(std::begin(ref), i);
+        err            = std::abs(o - r);
+        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
             if(err_count < 5)
             {
                 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
-                          << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;
         }

@@ -64,10 +70,13 @@ check_err(const std::vector<T>& out,
     return res;
 }

-template <typename T>
-typename std::enable_if<std::is_same<T, bhalf_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol            = 1e-3,
           double atol            = 1e-3)

@@ -86,9 +95,9 @@ check_err(const std::vector<T>& out,
     double max_err = std::numeric_limits<float>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        double o = type_convert<float>(out[i]);
-        double r = type_convert<float>(ref[i]);
-        err      = std::abs(o - r);
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
+        err            = std::abs(o - r);
         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;

@@ -108,10 +117,13 @@ check_err(const std::vector<T>& out,
     return res;
 }

-template <typename T>
-typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
-check_err(span<const T> out,
-          span<const T> ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, half_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol            = 1e-3,
           double atol            = 1e-3)

@@ -126,12 +138,12 @@ check_err(span<const T> out,
     bool res{true};
     int err_count  = 0;
     double err     = 0;
-    double max_err = std::numeric_limits<T>::min();
+    double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        double o = type_convert<float>(out[i]);
-        double r = type_convert<float>(ref[i]);
-        err      = std::abs(o - r);
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
+        err            = std::abs(o - r);
         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;

@@ -151,26 +163,17 @@ check_err(span<const T> out,
     return res;
 }

-template <typename T>
-typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
-          const std::string& msg = "Error: Incorrect results!",
-          double rtol            = 1e-3,
-          double atol            = 1e-3)
-{
-    return check_err(span<const T>{out}, span<const T>{ref}, msg, rtol, atol);
-}
-
-template <typename T>
-std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
+template <typename Range, typename RefRange>
+std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+                  std::is_integral_v<ranges::range_value_t<Range>> &&
+                  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-                     || std::is_same_v<T, int4_t>
+                     || std::is_same_v<ranges::range_value_t<Range>, int4_t>
 #endif
                  ,
                  bool>
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double                 = 0,
           double atol            = 0)

@@ -188,9 +191,9 @@ check_err(const std::vector<T>& out,
     int64_t max_err = std::numeric_limits<int64_t>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        int64_t o = out[i];
-        int64_t r = ref[i];
-        err       = std::abs(o - r);
+        const int64_t o = *std::next(std::begin(out), i);
+        const int64_t r = *std::next(std::begin(ref), i);
+        err             = std::abs(o - r);
         if(err > atol)
         {
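With the range-based overloads above, out and ref no longer have to be std::vector of the same type; any two ranges whose value types match are accepted. A short usage sketch with illustrative data and the default floating-point tolerances:

// Sketch: comparing a std::vector against a std::array of the same value type.
#include <array>
#include <vector>
#include "ck/library/utility/check_err.hpp"

bool compare_results()
{
    std::vector<float> out  = {1.0f, 2.0f, 3.0f};
    std::array<float, 3> ref = {1.0f, 2.0f, 3.000001f};

    // floating-point overload: passes when |out[i] - ref[i]| <= atol + rtol * |ref[i]|
    return ck::utils::check_err(out, ref, "Error: Incorrect results!", 1e-5, 3e-6);
}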
library/include/ck/library/utility/convolution_parameter.hpp  View file @ f0224f2a
...
@@ -10,6 +10,8 @@
 #include "ck/ck.hpp"
+#include "ck/library/utility/numeric.hpp"
 
 namespace ck {
 namespace utils {
 namespace conv {
...
@@ -55,10 +57,8 @@ struct ConvParam
         // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
         return sizeof(InDataType) *
                (G_ * N_ * C_ *
-                std::accumulate(std::begin(input_spatial_lengths_),
-                                std::begin(input_spatial_lengths_) + num_dim_spatial_,
-                                static_cast<std::size_t>(1),
-                                std::multiplies<std::size_t>()));
+                ck::accumulate_n<std::size_t>(
+                    std::begin(input_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
     }
 
     template <typename WeiDataType>
...
@@ -67,10 +67,8 @@ struct ConvParam
         // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
         return sizeof(WeiDataType) *
                (G_ * K_ * C_ *
-                std::accumulate(std::begin(filter_spatial_lengths_),
-                                std::begin(filter_spatial_lengths_) + num_dim_spatial_,
-                                static_cast<std::size_t>(1),
-                                std::multiplies<std::size_t>()));
+                ck::accumulate_n<std::size_t>(
+                    std::begin(filter_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
     }
 
     template <typename OutDataType>
...
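The call sites above rely on ck::accumulate_n from the newly included numeric.hpp, whose definition is not part of this hunk. From the way it is called, it presumably folds the first num_dim_spatial_ elements of the range; a sketch of such a helper, offered only as an assumption about its shape:

    #include <functional>
    #include <iterator>
    #include <numeric>

    namespace ck {

    // Assumed shape: accumulate exactly `count` elements starting at `first`.
    // The real header may differ in naming or constraints.
    template <typename T, typename ForwardIterator, typename Size, typename BinaryOperation>
    T accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op)
    {
        return std::accumulate(first, std::next(first, count), init, op);
    }

    } // namespace ck

With this shape, ck::accumulate_n<std::size_t>(std::begin(v), n, 1, std::multiplies<>()) computes the product of the first n entries of v, which is exactly what the byte-size calculations above need.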
library/include/ck/library/utility/fill.hpp  View file @ f0224f2a
...
@@ -30,9 +30,10 @@ struct FillUniformDistribution
     }
 
     template <typename ForwardRange>
-    auto operator()(ForwardRange&& range) -> std::void_t<decltype(
-        std::declval<FillUniformDistribution>()(std::begin(std::forward<ForwardRange>(range)),
-                                                std::end(std::forward<ForwardRange>(range))))>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
     {
         (*this)(std::begin(std::forward<ForwardRange>(range)),
                 std::end(std::forward<ForwardRange>(range)));
...
@@ -72,6 +73,16 @@ struct FillUniformDistributionIntegerValue
         std::generate(
             first, last, [&dis, &gen]() { return ck::type_convert<T>(std::round(dis(gen))); });
     }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t<decltype(std::declval<const FillUniformDistributionIntegerValue&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
 };
 
 template <typename T>
...
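The fill.hpp changes make the range overload const-qualified and add a matching one to FillUniformDistributionIntegerValue, so a filler can be held as a const object and applied to a whole container in a single call. A usage sketch, assuming the fillers live in ck::utils and are brace-initialized with their distribution bounds (neither detail is shown in this hunk):

    #include <vector>

    #include "ck/library/utility/fill.hpp"

    int main()
    {
        std::vector<float> values(1024);
        std::vector<int> labels(1024);

        // Range overload: pass the container itself instead of an iterator pair.
        const ck::utils::FillUniformDistribution<float> fill_values{-1.f, 1.f};
        fill_values(values);

        // New in this commit: the integer-value filler gains the same range overload.
        const ck::utils::FillUniformDistributionIntegerValue<int> fill_labels{0.f, 9.f};
        fill_labels(labels);
    }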
library/include/ck/library/utility/host_common_util.hpp  View file @ f0224f2a
...
@@ -4,9 +4,11 @@
 #pragma once
 
 #include <vector>
+#include <array>
 #include <iostream>
 #include <fstream>
 #include <string>
+#include <algorithm>
 
 #include "ck/ck.hpp"
...
@@ -72,5 +74,63 @@ static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
     return (values);
 }
 
+template <int NDim>
+static inline std::vector<std::array<index_t, NDim>>
+get_index_set(const std::array<index_t, NDim>& dim_lengths)
+{
+    static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");
+
+    if constexpr(NDim == 1)
+    {
+        std::vector<std::array<index_t, NDim>> index_set;
+
+        for(int i = 0; i < dim_lengths[0]; i++)
+        {
+            std::array<index_t, 1> index{i};
+
+            index_set.push_back(index);
+        };
+
+        return index_set;
+    }
+    else
+    {
+        std::vector<std::array<index_t, NDim>> index_set;
+        std::array<index_t, NDim - 1> partial_dim_lengths;
+
+        std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin());
+
+        std::vector<std::array<index_t, NDim - 1>> partial_index_set;
+
+        partial_index_set = get_index_set<NDim - 1>(partial_dim_lengths);
+
+        for(index_t i = 0; i < dim_lengths[0]; i++)
+            for(const auto& partial_index : partial_index_set)
+            {
+                std::array<index_t, NDim> index;
+
+                index[0] = i;
+
+                std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1);
+
+                index_set.push_back(index);
+            };
+
+        return index_set;
+    };
+};
+
+template <int NDim>
+static inline size_t get_offset_from_index(const std::array<index_t, NDim>& strides,
+                                           const std::array<index_t, NDim>& index)
+{
+    size_t offset = 0;
+
+    for(int i = 0; i < NDim; i++)
+        offset += index[i] * strides[i];
+
+    return (offset);
+};
+
 } // namespace host_common
 } // namespace ck
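The two helpers added above enumerate every multi-dimensional index of a given shape and map an index to a linear offset through a stride array. A small sketch of how they combine, assuming they are reachable through the ck::host_common namespace indicated by the closing braces:

    #include <array>
    #include <cstddef>

    #include "ck/library/utility/host_common_util.hpp"

    int main()
    {
        // A 2 x 3 tensor stored row-major, i.e. strides {3, 1}.
        std::array<ck::index_t, 2> lengths{2, 3};
        std::array<ck::index_t, 2> strides{3, 1};

        std::size_t checksum = 0;

        // get_index_set<2> yields {0,0}, {0,1}, ..., {1,2}; get_offset_from_index
        // turns each of them into its linear position 0..5 for this layout.
        for(const auto& idx : ck::host_common::get_index_set<2>(lengths))
            checksum += ck::host_common::get_offset_from_index<2>(strides, idx);

        return checksum == 15 ? 0 : 1; // 0 + 1 + 2 + 3 + 4 + 5
    }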
library/include/ck/library/utility/host_tensor.hpp  View file @ f0224f2a
...
@@ -14,6 +14,9 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/span.hpp"
+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/ranges.hpp"
 
 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
 {
...
@@ -84,10 +87,10 @@ struct HostTensorDescriptor
         this->CalculateStrides();
     }
 
-    template <typename Range,
+    template <typename Lengths,
               typename = std::enable_if_t<
-                  std::is_convertible_v<decltype(*std::begin(std::declval<Range>())), std::size_t>>>
-    HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
+                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t>>>
+    HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
     {
         this->CalculateStrides();
     }
...
@@ -102,13 +105,12 @@ struct HostTensorDescriptor
     {
     }
 
-    template <typename Range1,
-              typename Range2,
+    template <typename Lengths,
+              typename Strides,
               typename = std::enable_if_t<
-                  std::is_convertible_v<decltype(*std::begin(std::declval<Range1>())), std::size_t> &&
-                  std::is_convertible_v<decltype(*std::begin(std::declval<Range2>())), std::size_t>>>
-    HostTensorDescriptor(const Range1& lens, const Range2& strides)
+                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
+                  std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>>>
+    HostTensorDescriptor(const Lengths& lens, const Strides& strides)
         : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
     {
     }
...
@@ -244,14 +246,20 @@ struct Tensor
     {
     }
 
-    template <typename X>
-    Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    template <typename X, typename Y>
+    Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
+        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
     {
     }
 
-    template <typename X, typename Y>
-    Tensor(std::vector<X> lens, std::vector<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
+    template <typename Lengths>
+    Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    {
+    }
+
+    template <typename Lengths, typename Strides>
+    Tensor(const Lengths& lens, const Strides& strides)
+        : mDesc(lens, strides), mData(GetElementSpaceSize())
     {
     }
...
@@ -261,10 +269,10 @@ struct Tensor
     Tensor<OutT> CopyAsType() const
     {
         Tensor<OutT> ret(mDesc);
-        for(size_t i = 0; i < mData.size(); i++)
-        {
-            ret.mData[i] = ck::type_convert<OutT>(mData[i]);
-        }
+        ck::ranges::transform(
+            mData, ret.mData.begin(), [](auto value) { return ck::type_convert<OutT>(value); });
         return ret;
     }
...
@@ -294,13 +302,7 @@ struct Tensor
     std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
 
-    void SetZero()
-    {
-        for(auto& v : mData)
-        {
-            v = T{0};
-        }
-    }
+    void SetZero() { ck::ranges::fill<T>(mData, 0); }
 
     template <typename F>
     void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
...
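Taken together, the host_tensor.hpp changes let a Tensor be built from brace-initializer lists or from arbitrary length/stride ranges, and route SetZero and CopyAsType through the new ck::ranges helpers. A usage sketch under those assumptions (Tensor is taken to be the host-side container declared in this header, with public mData):

    #include <array>
    #include <cstdint>

    #include "ck/library/utility/host_tensor.hpp"

    int main()
    {
        // New initializer_list constructor: lengths {2, 3} with strides {3, 1}.
        Tensor<float> a({2, 3}, {3, 1});
        a.SetZero(); // now implemented with ck::ranges::fill

        // Generic range constructor: any container of lengths works, not only std::vector.
        std::array<std::size_t, 3> lengths{4, 4, 8};
        Tensor<float> b(lengths);

        // CopyAsType converts every element through ck::ranges::transform.
        Tensor<int32_t> a_as_int = a.CopyAsType<int32_t>();

        return a_as_int.mData.size() == 6 ? 0 : 1;
    }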