Commit 0e97ebaa authored by Umang Yadav's avatar Umang Yadav
Browse files

changes to make GSG work

parent 4cd24e66
......@@ -2,9 +2,10 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <vector>
#endif
#include "device_base.hpp"
......@@ -28,6 +29,7 @@ template <typename ALayout,
bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b0,
......@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
} // namespace device
......
......@@ -2,9 +2,12 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -15,8 +18,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -442,6 +443,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
D0sTransferSrcScalarPerVector>;
#ifndef __HIPCC_RTC__
// Argument
// FIXME: constness
struct Argument : public BaseArgument
......@@ -856,9 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_element_op,
c1de_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
// FIXME: constness
......@@ -948,6 +948,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return str.str();
}
#endif
};
} // namespace device
......
......@@ -3,8 +3,12 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -15,8 +19,6 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -126,7 +128,6 @@ __global__ void
// else
// AccElement = -INFINITY
// Otherwise, result may be wrong.
template <typename ALayout,
typename BLayout, // B0Layout
typename B1Layout,
......@@ -430,6 +431,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
matrix_padder.PadN,
MaskOutUpperTriangle>;
#ifndef __HIPCC_RTC__
// Argument
struct Argument : public BaseArgument
{
......@@ -604,7 +606,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
......@@ -700,7 +702,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return true;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -758,9 +760,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
b1_element_op, c_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -839,7 +839,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return str.str();
}
#endif
template <class ADesc, class BDesc, class B1Desc, class CDesc>
struct Descriptor
{
......@@ -1054,7 +1054,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid)
{
assert(desc.is_valid);
// assert(desc.is_valid);
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale};
......
......@@ -2,9 +2,12 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -15,8 +18,6 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
......@@ -431,6 +432,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
index_t block_start_, block_end_;
};
#ifndef __HIPCC_RTC_
struct GroupDeviceArg
{
// lengths for the last dimensions of overall problem for sanity check of vector load/store
......@@ -587,7 +589,6 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
};
#ifndef __HIPCC_RTC_
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -780,7 +781,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
#endif
static auto MakeArgument(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
......@@ -808,9 +809,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -889,6 +888,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg);
}
#endif
};
} // namespace device
......
......@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
MaskOutUpperTriangle
};
#ifndef __HIPCC_RTC__
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
switch(s)
......@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
default: return "Unrecognized specialization!";
}
}
#endif
struct MaskDisabledPredicate
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment