Commit 0b11569f authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute

parents e8d3a0fb fa9a0a5c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
set(DEVICE_ELEMENTWISE_INSTANCE_SOURCE
device_normalize_instance.cpp
)
add_instance_library(device_elementwise_instance ${DEVICE_ELEMENTWISE_INSTANCE_SOURCE})
target_compile_features(device_elementwise_instance PUBLIC)
set_target_properties(device_elementwise_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_elementwise_instance)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
using F16 = ck::half_t;
using F32 = float;
using inputType = F16;
using MeanType = F32;
using SquareMeanType = F32;
using GammaDataType = F16;
using BetaDataType = F16;
using outputType = F16;
using Normalize = ck::tensor_operation::element_wise::Normalize;
using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std::tuple<
// clang-format off
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 8, 8, 1, 1, 8, 8, 8 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 4, 4, 1, 1, 4, 4, 4 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 2, 2, 1, 1, 2, 2, 2 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 1, 1, 1, 1, 1, 1, 1 >
// clang-format on
>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceElementwisePtr<5, 1, 2, Normalize>>& instances)
{
add_device_operation_instances(
instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{});
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -28,14 +28,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ...@@ -28,14 +28,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment