Commit 253f942b authored by Umang Yadav's avatar Umang Yadav
Browse files

changes to make it compile

parent 8f9c0243
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -14,9 +23,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
......@@ -465,6 +471,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
}
}
#ifndef __HIPCC_RTC__
void Print() const
{
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
......@@ -472,6 +479,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
}
#endif
// private:
const ADataType* p_a_grid_;
......@@ -497,6 +505,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
std::vector<index_t> raw_lengths_m_n_k_o_;
};
#ifndef __HIPCC_RTC__
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -580,13 +589,13 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -631,6 +640,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
#endif
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
......@@ -662,7 +672,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
b1_element_op, c_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -712,6 +725,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
c_element_op);
}
#ifndef __HIPCC_RTC__
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
......@@ -741,8 +755,11 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
return str.str();
}
#endif
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -642,8 +645,9 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
b_element_op,
cde_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -719,3 +723,5 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -715,8 +718,9 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
b_element_op,
cde_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -793,3 +797,5 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -894,8 +897,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -993,3 +997,5 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -898,8 +901,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
reduce_out_element_ops,
Batch};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -1005,3 +1009,5 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -856,8 +859,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_element_op,
c1de_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
// FIXME: constness
......@@ -952,3 +956,5 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -685,8 +688,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
b1_element_op, c_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -770,3 +774,5 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -357,8 +360,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BatchStrideC,
Batch};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -433,3 +437,5 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -872,3 +875,5 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -820,3 +823,5 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -500,8 +503,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
StrideB,
StrideC};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
......@@ -585,3 +589,5 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -714,8 +717,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
b_element_op,
cde_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -787,3 +791,5 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -806,3 +809,5 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -694,8 +697,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
input_left_pads,
input_right_pads};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid,
......@@ -763,3 +767,5 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -903,8 +906,9 @@ struct
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -983,3 +987,5 @@ struct
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -860,8 +863,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -940,3 +944,5 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -830,8 +833,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -906,3 +910,5 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -605,8 +608,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
input_left_pads,
input_right_pads};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......@@ -677,3 +681,5 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -203,8 +206,9 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -266,3 +270,5 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
} // namespace tensor_operation
} // namespace ck
#endif
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -577,8 +580,9 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -660,3 +664,5 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
} // namespace tensor_operation
} // namespace ck
#endif
#pragma clang diagnostic pop
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment