Commit a30c626b authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Make ComputeDataType an optional argument

parent b019d839
......@@ -30,7 +30,6 @@ using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -141,7 +140,6 @@ int main(int argc, char* argv[])
BDataType,
ck::Tuple<DDataType>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
......
......@@ -30,7 +30,6 @@ using CShuffleDataType = F64;
using DDataType = F64;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F64;
using ComputeDataType = F64;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -186,7 +185,6 @@ int main(int argc, char* argv[])
BDataType,
ck::Tuple<DDataType>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
......
......@@ -29,7 +29,6 @@ using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -128,7 +127,6 @@ int main(int argc, char* argv[])
BDataType,
ck::Tuple<>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>;
......
......@@ -29,7 +29,6 @@ using AccDataType = F64;
using CShuffleDataType = F64;
using DsDataType = ck::Tuple<>;
using EDataType = F64;
using ComputeDataType = F64;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
......@@ -176,7 +175,6 @@ int main(int argc, char* argv[])
BDataType,
ck::Tuple<>,
EDataType,
ComputeDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Scale>;
......
......@@ -31,10 +31,10 @@ template <index_t NumDimM,
typename BDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
typename CDEElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
......
......@@ -155,10 +155,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_
BF16,
BF16_Tuple,
BF16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
F16,
F16_Tuple,
F16,
F32,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
F32>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
BF16>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
BF16>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
BF16>>>& instances)
{
add_device_operation_instances(
instances,
......
......@@ -41,10 +41,10 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_comp
F32,
F32_Tuple,
F32,
BF16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
Bilinear,
BF16>>>& instances)
{
add_device_operation_instances(
instances,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment