Commit 253f942b authored by Umang Yadav's avatar Umang Yadav
Browse files

changes to make it compile

parent 8f9c0243
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -62,3 +65,5 @@ using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout, ...@@ -62,3 +65,5 @@ using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -70,3 +73,5 @@ struct DeviceGroupedContractionMultipleD : public BaseOperator ...@@ -70,3 +73,5 @@ struct DeviceGroupedContractionMultipleD : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -65,3 +68,5 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator ...@@ -65,3 +68,5 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -48,3 +51,5 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator ...@@ -48,3 +51,5 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -53,3 +56,5 @@ struct DeviceGroupedConvFwd : public BaseOperator ...@@ -53,3 +56,5 @@ struct DeviceGroupedConvFwd : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -63,3 +66,5 @@ struct DeviceGroupedConvFwdMultipleD : public BaseOperator ...@@ -63,3 +66,5 @@ struct DeviceGroupedConvFwdMultipleD : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -53,3 +56,5 @@ struct DeviceGroupedGemm : public BaseOperator ...@@ -53,3 +56,5 @@ struct DeviceGroupedGemm : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -73,3 +76,5 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator ...@@ -73,3 +76,5 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -37,3 +40,5 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout, ...@@ -37,3 +40,5 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -56,3 +59,5 @@ using DeviceMultipleReducePtr = std::unique_ptr<DeviceMultipleReduce<Rank, ...@@ -56,3 +59,5 @@ using DeviceMultipleReducePtr = std::unique_ptr<DeviceMultipleReduce<Rank,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -60,3 +63,5 @@ using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType, ...@@ -60,3 +63,5 @@ using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <array> #include <array>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#endif
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
...@@ -34,3 +39,5 @@ struct DevicePermute : BaseOperator ...@@ -34,3 +39,5 @@ struct DevicePermute : BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -45,3 +48,5 @@ struct DevicePoolFwd : public BaseOperator ...@@ -45,3 +48,5 @@ struct DevicePoolFwd : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -69,3 +72,5 @@ using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType, ...@@ -69,3 +72,5 @@ using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -70,3 +73,5 @@ using DeviceSoftmaxPtr = std::unique_ptr<DeviceSoftmax<InDataType, ...@@ -70,3 +73,5 @@ using DeviceSoftmaxPtr = std::unique_ptr<DeviceSoftmax<InDataType,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -63,3 +66,5 @@ struct DeviceSplitKContractionMultipleD : public BaseOperator ...@@ -63,3 +66,5 @@ struct DeviceSplitKContractionMultipleD : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -28,7 +31,7 @@ enum struct GemmSpecialization ...@@ -28,7 +31,7 @@ enum struct GemmSpecialization
NKOPadding, NKOPadding,
MNKOPadding, MNKOPadding,
}; };
#ifndef __HIPCC_RTC__
inline std::string getGemmSpecializationString(const GemmSpecialization& s) inline std::string getGemmSpecializationString(const GemmSpecialization& s)
{ {
switch(s) switch(s)
...@@ -52,7 +55,10 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s) ...@@ -52,7 +55,10 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
#endif
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -15,8 +23,6 @@ ...@@ -15,8 +23,6 @@
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -687,7 +693,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -687,7 +693,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Batch Offset // Batch Offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
}; };
#ifndef __HIPCC_RTC__
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
...@@ -761,13 +767,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -761,13 +767,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
#endif
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
// TODO: properly implement this check // TODO: properly implement this check
return true; return true;
} }
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
...@@ -869,7 +876,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -869,7 +876,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
{ {
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
#endif
static auto static auto
MakeArgument(const void* p_a, MakeArgument(const void* p_a,
const void* p_b, const void* p_b,
...@@ -943,7 +950,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -943,7 +950,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
cde_element_op); cde_element_op);
} }
static auto MakeInvoker() { return Invoker{}; } #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
...@@ -985,8 +993,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -985,8 +993,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
return str.str(); return str.str();
} }
#endif
}; };
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -15,8 +23,6 @@ ...@@ -15,8 +23,6 @@
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
...@@ -759,6 +765,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -759,6 +765,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
}; };
#ifndef __HIPCC_RTC__
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
...@@ -841,7 +848,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -841,7 +848,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported()) if(!ck::is_xdl_supported())
...@@ -935,7 +941,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -935,7 +941,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
{ {
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
#endif
static auto static auto
MakeArgument(const void* p_a, MakeArgument(const void* p_a,
const void* p_b, const void* p_b,
...@@ -970,7 +976,10 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -970,7 +976,10 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
cde_element_op}; cde_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -1006,7 +1015,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -1006,7 +1015,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
b_element_op, b_element_op,
cde_element_op); cde_element_op);
} }
#ifndef __HIPCC_RTC__
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
...@@ -1038,8 +1047,11 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -1038,8 +1047,11 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
return str.str(); return str.str();
} }
#endif
}; };
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -11,8 +19,6 @@ ...@@ -11,8 +19,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -490,6 +496,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout, ...@@ -490,6 +496,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
}; };
#ifndef __HIPCC_RTC__
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
...@@ -565,13 +572,14 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout, ...@@ -565,13 +572,14 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
#endif
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
// TODO: properly implement this check // TODO: properly implement this check
return true; return true;
} }
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported()) if(!ck::is_xdl_supported())
...@@ -591,7 +599,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout, ...@@ -591,7 +599,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
{ {
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
#endif
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
EDataType* p_e, EDataType* p_e,
...@@ -625,7 +633,10 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout, ...@@ -625,7 +633,10 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
cde_element_op}; cde_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } #ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -662,6 +673,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout, ...@@ -662,6 +673,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
cde_element_op); cde_element_op);
} }
#ifndef __HIPCC_RTC__
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
...@@ -685,8 +697,11 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout, ...@@ -685,8 +697,11 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
return str.str(); return str.str();
} }
#endif
}; };
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment