Commit bfa7e780 authored by Umang Yadav
Browse files

Undo some more changes

parent 3b1e790e
......@@ -2,12 +2,9 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -18,6 +15,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -443,7 +442,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
D0sTransferSrcScalarPerVector>;
#ifndef __HIPCC_RTC__
// Argument
// FIXME: constness
struct Argument : public BaseArgument
......@@ -858,6 +856,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_element_op,
c1de_element_op};
}
// Factory: returns a value-initialized Invoker for this device operation.
static auto MakeInvoker()
{
    return Invoker{};
}
// polymorphic
......@@ -948,7 +947,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return str.str();
}
#endif
};
} // namespace device
......
......@@ -1054,7 +1054,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid)
{
// assert(desc.is_valid);
#ifndef __HIPCC_RTC__
assert(desc.is_valid);
#endif
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale};
......
......@@ -2,12 +2,9 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -18,6 +15,8 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -432,7 +431,6 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
index_t block_start_, block_end_;
};
#ifndef __HIPCC_RTC_
struct GroupDeviceArg
{
// lengths for the last dimensions of overall problem for sanity check of vector load/store
......@@ -589,6 +587,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -673,14 +672,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
// Compile-time hook reporting whether this instantiation's parameters are valid.
// No real validation exists yet, so every configuration is accepted.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all = true;
    return accept_all;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -888,7 +886,6 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg);
}
#endif
};
} // namespace device
......
......@@ -657,9 +657,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
return Argument{
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
}
#ifndef __HIPCC_RTC__
// Factory: returns a value-initialized Invoker for this device operation.
static auto MakeInvoker()
{
    return Invoker{};
}
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......
......@@ -386,9 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
in_elementwise_op,
acc_elementwise_op);
};
#ifndef __HIPCC_RTC__
// Factory: returns a value-initialized Invoker for this device operation.
static auto MakeInvoker()
{
    return Invoker{};
}
#endif
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment