Commit 253f942b authored by Umang Yadav's avatar Umang Yadav
Browse files

changes to make it compile

parent 8f9c0243
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -1515,8 +1518,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1515,8 +1518,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
wei_element_op, wei_element_op,
out_element_op}; out_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid, MakeArgumentPointer(void* p_in_grid,
...@@ -1586,3 +1590,5 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1586,3 +1590,5 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -1399,8 +1402,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1399,8 +1402,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads, input_left_pads,
input_right_pads}; input_right_pads};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid, MakeArgumentPointer(void* p_in_grid,
...@@ -1473,3 +1477,5 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1473,3 +1477,5 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -325,8 +328,9 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -325,8 +328,9 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
out_dev_buffers, out_dev_buffers,
elementwise_op); elementwise_op);
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(); return std::make_unique<Invoker>();
...@@ -336,3 +340,5 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -336,3 +340,5 @@ struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -290,8 +293,9 @@ struct DeviceElementwiseImpl ...@@ -290,8 +293,9 @@ struct DeviceElementwiseImpl
out_dev_buffers, out_dev_buffers,
elementwise_op); elementwise_op);
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(); return std::make_unique<Invoker>();
...@@ -301,3 +305,5 @@ struct DeviceElementwiseImpl ...@@ -301,3 +305,5 @@ struct DeviceElementwiseImpl
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -595,3 +598,5 @@ struct DeviceElementwiseNormalizationImpl ...@@ -595,3 +598,5 @@ struct DeviceElementwiseNormalizationImpl
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -771,8 +774,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO ...@@ -771,8 +774,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
reduce_in_element_ops, reduce_in_element_ops,
reduce_out_element_ops}; reduce_out_element_ops};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -886,3 +890,5 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO ...@@ -886,3 +890,5 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -546,8 +549,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -546,8 +549,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
b_element_op, b_element_op,
c_element_op}; c_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -611,3 +615,5 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -611,3 +615,5 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -603,8 +606,9 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout, ...@@ -603,8 +606,9 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
b_element_op, b_element_op,
cde_element_op}; cde_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -671,3 +675,5 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout, ...@@ -671,3 +675,5 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -996,8 +999,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -996,8 +999,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
cde_element_op, cde_element_op,
h_element_op}; h_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -1085,3 +1089,5 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -1085,3 +1089,5 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -611,8 +614,9 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -611,8 +614,9 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
qs_element_op, qs_element_op,
rs_element_op}; rs_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -691,3 +695,5 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle ...@@ -691,3 +695,5 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -656,8 +659,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -656,8 +659,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
...@@ -704,3 +708,5 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -704,3 +708,5 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/array.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/utility/common_header.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
...@@ -34,22 +39,22 @@ template <typename GridwiseGemm, ...@@ -34,22 +39,22 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid, kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -223,9 +228,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -223,9 +228,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
} }
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws, static auto MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws, const ck::Array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride) const ck::Array<index_t, NumDTensor>& DsStride)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -310,14 +315,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -310,14 +315,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{ {
Argument(const void* p_a_grid, Argument(const void* p_a_grid,
const void* p_b_grid, const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid, ck::Array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid, void* p_e_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, ck::Array<index_t, NumDTensor> StrideDs,
index_t StrideE, index_t StrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -374,6 +379,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -374,6 +379,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
} }
#ifndef __HIPCC_RTC__
void Print() const void Print() const
{ {
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
...@@ -382,6 +388,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -382,6 +388,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
} }
#endif
// private: // private:
// pointers // pointers
...@@ -416,7 +423,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -416,7 +423,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
index_t NRaw_; index_t NRaw_;
index_t KRaw_; index_t KRaw_;
}; };
#ifndef __HIPCC_RTC__
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
...@@ -492,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -492,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
#endif
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_) static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
{ {
// check vector load/store // check vector load/store
...@@ -574,7 +581,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -574,7 +581,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
return true; return true;
} }
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported()) if(!ck::is_xdl_supported())
...@@ -595,17 +602,17 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -595,17 +602,17 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{ {
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
#endif
static auto MakeArgument(const void* p_a, static auto MakeArgument(const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, ck::Array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, ck::Array<index_t, NumDTensor> StrideDs,
index_t StrideE, index_t StrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -627,20 +634,22 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -627,20 +634,22 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
cde_element_op}; cde_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, ck::Array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs, ck::Array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE, index_t StrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -662,6 +671,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -662,6 +671,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
cde_element_op); cde_element_op);
} }
#ifndef __HIPCC_RTC__
// polymorphic // polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
...@@ -673,11 +683,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -673,11 +683,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{ std::map<LoopScheduler, std::string> LoopSchedToString{{LoopScheduler::Default, "Default"},
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; { LoopScheduler::Interwave,
"Interwave" }};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"}, std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}}; { PipelineVersion::v2,
"v2" }};
// clang-format off // clang-format off
str << "DeviceGemmMultipleD_Xdl_CShuffle" str << "DeviceGemmMultipleD_Xdl_CShuffle"
...@@ -706,6 +718,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -706,6 +718,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return str.str(); return str.str();
} }
#endif
template <class ADesc, class BDesc, class DsDesc, class EDesc> template <class ADesc, class BDesc, class DsDesc, class EDesc>
struct Descriptor struct Descriptor
...@@ -722,10 +735,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -722,10 +735,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1( remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>; DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_tuple()))>; decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( ds_tuple()))>;
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>; DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap( using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>; DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
...@@ -735,7 +749,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -735,7 +749,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
Block2ETileMap block_2_etile_map; Block2ETileMap block_2_etile_map;
// element-wise op // element-wise op
AElementwiseOperation a_element_op; AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op; BElementwiseOperation b_element_op;
...@@ -786,10 +800,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -786,10 +800,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{ {
} }
constexpr bool IsValid() const constexpr bool IsValid() const { return is_valid; }
{
return is_valid;
}
}; };
template <class ADesc, class BDesc, class DsDesc, class EDesc> template <class ADesc, class BDesc, class DsDesc, class EDesc>
...@@ -807,14 +818,16 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -807,14 +818,16 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
template <class Desc, class DsPointer> template <class Desc, class DsPointer>
__device__ static void Run(const Desc& desc, __device__ static void Run(Desc desc,
const ADataType* __restrict__ p_a_grid, const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid, const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid) EDataType* __restrict__ p_e_grid)
{ {
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
#ifndef __HIPCC_RTC__
assert(desc.is_valid); assert(desc.is_valid);
#endif
if(desc.has_main_k_block_loop) if(desc.has_main_k_block_loop)
{ {
GridwiseGemm::template Run<true>(p_a_grid, GridwiseGemm::template Run<true>(p_a_grid,
...@@ -853,3 +866,5 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -853,3 +866,5 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -732,8 +735,9 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio ...@@ -732,8 +735,9 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
reduce_in_element_ops, reduce_in_element_ops,
reduce_out_element_ops}; reduce_out_element_ops};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -846,3 +850,5 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio ...@@ -846,3 +850,5 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -526,8 +529,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -526,8 +529,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
b_element_op, b_element_op,
c_element_op}; c_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -604,3 +608,5 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -604,3 +608,5 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -226,8 +229,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -226,8 +229,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
{ {
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -303,3 +307,5 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -303,3 +307,5 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -228,8 +231,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -228,8 +231,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{ {
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -304,3 +308,5 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -304,3 +308,5 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -701,8 +704,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -701,8 +704,9 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
acc_element_op, acc_element_op,
c_element_op}; c_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
...@@ -779,3 +783,5 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -779,3 +783,5 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -462,8 +465,9 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -462,8 +465,9 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
b_element_op, b_element_op,
c_element_op}; c_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -528,3 +532,5 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout, ...@@ -528,3 +532,5 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -280,8 +283,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -280,8 +283,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateK0(K, KBatch), GridwiseGemm::CalculateK0(K, KBatch),
KBatch}; KBatch};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -327,3 +331,5 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -327,3 +331,5 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -14,8 +21,6 @@ ...@@ -14,8 +21,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck { namespace ck {
...@@ -417,6 +422,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -417,6 +422,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
} }
}; };
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported()) if(!ck::is_xdl_supported())
...@@ -435,6 +441,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -435,6 +441,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
{ {
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
#endif
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
...@@ -462,8 +469,9 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -462,8 +469,9 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
b_element_op, b_element_op,
cde_element_op}; cde_element_op};
} }
#ifndef __HIPCC_RTC__
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -525,3 +533,5 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout, ...@@ -525,3 +533,5 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment