Commit 0e97ebaa authored by Umang Yadav
Browse files

changes to make GSG work

parent 4cd24e66
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#endif
#include "device_base.hpp" #include "device_base.hpp"
...@@ -28,6 +29,7 @@ template <typename ALayout, ...@@ -28,6 +29,7 @@ template <typename ALayout,
bool MaskOutUpperTriangle> // TODO: enum for mask type bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{ {
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b0, const void* p_b0,
...@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator ...@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
CElementwiseOperation c_element_op) = 0; CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
}; };
} // namespace device } // namespace device
......
...@@ -2,9 +2,12 @@ ...@@ -2,9 +2,12 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -15,8 +18,6 @@ ...@@ -15,8 +18,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -442,6 +443,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -442,6 +443,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
D0sTransferSrcScalarPerVector>; D0sTransferSrcScalarPerVector>;
#ifndef __HIPCC_RTC__
// Argument // Argument
// FIXME: constness // FIXME: constness
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -856,9 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -856,9 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_element_op, b1_element_op,
c1de_element_op}; c1de_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
// FIXME: constness // FIXME: constness
...@@ -948,6 +948,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -948,6 +948,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return str.str(); return str.str();
} }
#endif
}; };
} // namespace device } // namespace device
......
...@@ -3,8 +3,12 @@ ...@@ -3,8 +3,12 @@
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -15,8 +19,6 @@ ...@@ -15,8 +19,6 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -126,7 +128,6 @@ __global__ void ...@@ -126,7 +128,6 @@ __global__ void
// else // else
// AccElement = -INFINITY // AccElement = -INFINITY
// Otherwise, result may be wrong. // Otherwise, result may be wrong.
template <typename ALayout, template <typename ALayout,
typename BLayout, // B0Layout typename BLayout, // B0Layout
typename B1Layout, typename B1Layout,
...@@ -430,6 +431,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -430,6 +431,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
matrix_padder.PadN, matrix_padder.PadN,
MaskOutUpperTriangle>; MaskOutUpperTriangle>;
#ifndef __HIPCC_RTC__
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -604,7 +606,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -604,7 +606,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
#endif
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
// TODO: properly implement this check // TODO: properly implement this check
...@@ -700,7 +702,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -700,7 +702,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return true; return true;
} }
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported()) if(!ck::is_xdl_supported())
...@@ -758,9 +760,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -758,9 +760,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op, BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
b1_element_op, c_element_op}; b1_element_op, c_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -839,7 +839,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -839,7 +839,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return str.str(); return str.str();
} }
#endif
template <class ADesc, class BDesc, class B1Desc, class CDesc> template <class ADesc, class BDesc, class B1Desc, class CDesc>
struct Descriptor struct Descriptor
{ {
...@@ -1054,7 +1054,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1054,7 +1054,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const ADataType* __restrict__ p_b1_grid, const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid) CDataType* __restrict__ p_c_grid)
{ {
assert(desc.is_valid); // assert(desc.is_valid);
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale}; AccElementwiseOperation acc_element_op{scale};
......
...@@ -2,9 +2,12 @@ ...@@ -2,9 +2,12 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -15,8 +18,6 @@ ...@@ -15,8 +18,6 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -431,6 +432,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -431,6 +432,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
index_t block_start_, block_end_; index_t block_start_, block_end_;
}; };
#ifndef __HIPCC_RTC_
struct GroupDeviceArg struct GroupDeviceArg
{ {
// lengths for the last dimensions of overall problem for sanity check of vector load/store // lengths for the last dimensions of overall problem for sanity check of vector load/store
...@@ -587,7 +589,6 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -587,7 +589,6 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1ElementwiseOperation b1_element_op_; B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
}; };
#ifndef __HIPCC_RTC_
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
...@@ -780,7 +781,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -780,7 +781,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{ {
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
#endif
static auto MakeArgument(std::vector<const void*> p_a_vec, static auto MakeArgument(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b_vec, std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
...@@ -808,9 +809,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -808,9 +809,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_element_op}; c_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -889,6 +888,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -889,6 +888,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
{ {
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg); return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg);
} }
#endif
}; };
} // namespace device } // namespace device
......
...@@ -13,6 +13,7 @@ enum struct MaskingSpecialization ...@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
MaskOutUpperTriangle MaskOutUpperTriangle
}; };
#ifndef __HIPCC_RTC__
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{ {
switch(s) switch(s)
...@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s ...@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
#endif
struct MaskDisabledPredicate struct MaskDisabledPredicate
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment