Commit 65a0dafd authored by Umang Yadav
Browse files

Undo some changes

parent 0e97ebaa
......@@ -12,6 +12,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
#ifndef __HIPCC_RTC__
struct BaseArgument
{
BaseArgument() = default;
......@@ -23,7 +24,6 @@ struct BaseArgument
void* p_workspace_ = nullptr;
};
#ifndef __HIPCC_RTC__
struct BaseInvoker
{
BaseInvoker() = default;
......@@ -45,9 +45,9 @@ struct BaseOperator
BaseOperator(const BaseOperator&) = default;
BaseOperator& operator=(const BaseOperator&) = default;
#ifndef __HIPCC_RTC__
// Default implementation: ignores the argument and always returns false.
// Virtual, so a derived operator can replace it with a real check.
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
#ifndef __HIPCC_RTC__
// Default type description: an empty string.
virtual std::string GetTypeString() const { return ""; }
// Returns the implementation-defined name that typeid reports for the
// dynamic type of this object (the class is polymorphic, so typeid(*this)
// is evaluated at run time).
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
......@@ -60,15 +60,14 @@ struct BaseOperator
return oss.str();
};
#endif
// Default implementation: ignores the argument and always returns 0.
// NOTE(review): presumably the required workspace size in bytes - confirm the
// unit against the overriding operators.
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
// Stores p_workspace in the argument's p_workspace_ member.
//   p_arg       - argument object to update; must not be null.
//   p_workspace - pointer value to record; it is stored as-is and not
//                 dereferenced here.
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
{
// p_arg is dereferenced on the next statement, so guard against null.
// The check is only active when assert is enabled (NDEBUG not defined).
assert(p_arg);
p_arg->p_workspace_ = p_workspace;
}
#endif
// Virtual destructor with an empty body.
virtual ~BaseOperator() {}
};
......
......@@ -36,6 +36,7 @@ struct DeviceGemmMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
......@@ -51,8 +52,7 @@ struct DeviceGemmMultipleD : public BaseOperator
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
......
......@@ -3,12 +3,8 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -19,6 +15,8 @@
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -689,7 +687,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Batch Offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
};
#ifndef __HIPCC_RTC__
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -763,14 +761,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
// Placeholder validity check, usable in constant expressions.
// TODO: properly implement this check; until then it unconditionally
// reports true.
static constexpr bool IsValidCompilationParameter() { return true; }
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
......@@ -872,7 +869,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
#endif
static auto
MakeArgument(const void* p_a,
const void* p_b,
......@@ -946,7 +943,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
cde_element_op);
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
......@@ -989,7 +985,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
return str.str();
}
#endif
};
} // namespace device
......
......@@ -3,12 +3,8 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -19,6 +15,8 @@
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
......@@ -761,7 +759,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
};
#ifndef __HIPCC_RTC__
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -844,6 +841,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -937,7 +935,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
#endif
static auto
MakeArgument(const void* p_a,
const void* p_b,
......@@ -972,9 +970,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
cde_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -1010,7 +1006,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
b_element_op,
cde_element_op);
}
#ifndef __HIPCC_RTC__
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
......@@ -1042,7 +1038,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
return str.str();
}
#endif
};
} // namespace device
......
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -15,6 +11,8 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -492,7 +490,6 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
CDEElementwiseOperation cde_element_op_;
};
#ifndef __HIPCC_RTC__
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -568,14 +565,13 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
// Placeholder validity check, usable in constant expressions.
// TODO: properly implement this check; until then it unconditionally
// reports true.
static constexpr bool IsValidCompilationParameter() { return true; }
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -595,7 +591,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
#endif
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
EDataType* p_e,
......@@ -629,9 +625,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
cde_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -668,7 +662,6 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
cde_element_op);
}
#ifndef __HIPCC_RTC__
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
......@@ -692,7 +685,6 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
return str.str();
}
#endif
};
} // namespace device
......
......@@ -3,13 +3,8 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -19,6 +14,9 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
......@@ -467,7 +465,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
}
}
#ifndef __HIPCC_RTC__
void Print() const
{
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
......@@ -475,7 +472,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
}
#endif
// private:
const ADataType* p_a_grid_;
......@@ -501,7 +497,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
std::vector<index_t> raw_lengths_m_n_k_o_;
};
#ifndef __HIPCC_RTC__
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -585,13 +580,13 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
// Placeholder validity check, usable in constant expressions.
// TODO: properly implement this check; until then it unconditionally
// reports true.
static constexpr bool IsValidCompilationParameter() { return true; }
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -636,7 +631,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
#endif
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
......@@ -668,9 +662,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
b1_element_op, c_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -720,7 +712,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
c_element_op);
}
#ifndef __HIPCC_RTC__
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
......@@ -750,7 +741,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
return str.str();
}
#endif
};
} // namespace device
......
......@@ -642,9 +642,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
b_element_op,
cde_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......
......@@ -715,9 +715,8 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
b_element_op,
cde_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......
......@@ -894,9 +894,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......
......@@ -898,9 +898,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
reduce_out_element_ops,
Batch};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......
......@@ -357,9 +357,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BatchStrideC,
Batch};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......
......@@ -500,9 +500,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
StrideB,
StrideC};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
......
......@@ -714,9 +714,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
b_element_op,
cde_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......
......@@ -694,9 +694,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
input_left_pads,
input_right_pads};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid,
......
......@@ -903,9 +903,8 @@ struct
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......
......@@ -860,9 +860,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......
......@@ -830,9 +830,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......
......@@ -605,9 +605,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
input_left_pads,
input_right_pads};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
......
......@@ -203,9 +203,8 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......
......@@ -577,9 +577,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
wei_element_op,
out_element_op};
}
#ifndef __HIPCC_RTC__
// Returns a new Invoker, created with empty-brace initialization, by value.
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment