Commit 253f942b authored by Umang Yadav
Browse files

changes to make it compile

parent 8f9c0243
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -1515,8 +1518,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
wei_element_op,
out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid,
......@@ -1586,3 +1590,5 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -1399,8 +1402,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads,
input_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid,
......@@ -1473,3 +1477,5 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -325,8 +328,9 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
......@@ -336,3 +340,5 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -290,8 +293,9 @@ struct DeviceElementwiseImpl
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
......@@ -301,3 +305,5 @@ struct DeviceElementwiseImpl
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -595,3 +598,5 @@ struct DeviceElementwiseNormalizationImpl
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -771,8 +774,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
reduce_in_element_ops,
reduce_out_element_ops};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -886,3 +890,5 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -546,8 +549,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -611,3 +615,5 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -603,8 +606,9 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -671,3 +675,5 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -996,8 +999,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
cde_element_op,
h_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -1085,3 +1089,5 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -611,8 +614,9 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
qs_element_op,
rs_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -691,3 +695,5 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -656,8 +659,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
b_element_op,
cde_element_op);
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......@@ -704,3 +708,5 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/utility/array.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/common_header.hpp"
namespace ck {
......@@ -34,22 +39,22 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......@@ -223,9 +228,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
static auto MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
const ck::Array<index_t, NumDTensor>& NRaws,
const ck::Array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
......@@ -310,14 +315,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
ck::Array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
ck::Array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -374,6 +379,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
#ifndef __HIPCC_RTC__
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
......@@ -382,6 +388,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
#endif
// private:
// pointers
......@@ -416,7 +423,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
index_t NRaw_;
index_t KRaw_;
};
#ifndef __HIPCC_RTC__
// Invoker
struct Invoker : public BaseInvoker
{
......@@ -492,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
{
// check vector load/store
......@@ -574,7 +581,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
return true;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -595,17 +602,17 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
#endif
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
ck::Array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
ck::Array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -627,20 +634,22 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
cde_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
ck::Array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::Array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -662,6 +671,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
cde_element_op);
}
#ifndef __HIPCC_RTC__
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
......@@ -673,11 +683,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<LoopScheduler, std::string> LoopSchedToString{{LoopScheduler::Default, "Default"},
{ LoopScheduler::Interwave,
"Interwave" }};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
{ PipelineVersion::v2,
"v2" }};
// clang-format off
str << "DeviceGemmMultipleD_Xdl_CShuffle"
......@@ -706,6 +718,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return str.str();
}
#endif
template <class ADesc, class BDesc, class DsDesc, class EDesc>
struct Descriptor
......@@ -722,10 +735,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_tuple()))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_tuple()))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
......@@ -735,7 +749,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
Block2ETileMap block_2_etile_map;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
......@@ -786,10 +800,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
}
constexpr bool IsValid() const
{
return is_valid;
}
constexpr bool IsValid() const { return is_valid; }
};
template <class ADesc, class BDesc, class DsDesc, class EDesc>
......@@ -807,14 +818,16 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
template <class Desc, class DsPointer>
__device__ static void Run(const Desc& desc,
__device__ static void Run(Desc desc,
const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid)
{
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
#ifndef __HIPCC_RTC__
assert(desc.is_valid);
#endif
if(desc.has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(p_a_grid,
......@@ -853,3 +866,5 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -732,8 +735,9 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
reduce_in_element_ops,
reduce_out_element_ops};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -846,3 +850,5 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -526,8 +529,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -604,3 +608,5 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -226,8 +229,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -303,3 +307,5 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -228,8 +231,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -304,3 +308,5 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -701,8 +704,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
acc_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
......@@ -779,3 +783,5 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -462,8 +465,9 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -528,3 +532,5 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -280,8 +283,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateK0(K, KBatch),
KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -327,3 +331,5 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
......@@ -14,8 +21,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
......@@ -417,6 +422,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
}
};
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
......@@ -435,6 +441,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
#endif
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
......@@ -462,8 +469,9 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
b_element_op,
cde_element_op};
}
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic
std::unique_ptr<BaseArgument>
......@@ -525,3 +533,5 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
} // namespace device
} // namespace tensor_operation
} // namespace ck
#pragma clang diagnostic pop
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment