Unverified commit 7d50244e, authored by Illia Silin and committed by GitHub
Browse files

Merge pull request #209 from ROCm/andriy/merge_from_public

Update develop branch from public repository
parents f221c2b0 d51701d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Explicit instantiations of the fp16 fused add + rmsnorm2d + row-dynamic-quant
// forward launcher for 4x64-thread tiles (single-pass pipeline: last flag false).
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// Legend (trait_ args after the data type):
//   rm/rn = per-thread repeats along M/N, tm/tn = threads per block along M/N,
//   vn = vector width along N, pd = pad-N flag, x = save-X flag,
//   3p = multi-pass flag (NOTE(review): the alias names this kTwoPass_ while
//   the launcher reads Traits_::kThreePass — confirm intended naming).
//                                                        rm rn tm tn  vn  pd     x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Explicit instantiations of the fp16 fused add + rmsnorm2d + row-dynamic-quant
// forward launcher: wide-row, single-row-of-threads tiles (tm=1), one-pass.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm/rn = per-thread repeats (M/N), tm/tn = block threads (M/N),
// vn = N vector width, pd = pad-N, x = save-X flag, 3p = multi-pass flag
//                                                        rm rn tm  tn   vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Explicit instantiations of the fp16 fused add + rmsnorm2d + row-dynamic-quant
// forward launcher: wide-row tiles (tm=1, tn up to 1024), one-pass pipeline.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm/rn = per-thread repeats (M/N), tm/tn = block threads (M/N),
// vn = N vector width, pd = pad-N, x = save-X flag, 3p = multi-pass flag
//                                                        rm rn tm  tn   vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Explicit instantiations of the fp16 fused add + rmsnorm2d + row-dynamic-quant
// forward launcher: same tile shapes as the one-pass wide-row set, but with the
// multi-pass flag set (last arg true) — selects the ThreePass pipeline in the
// common launcher.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm/rn = per-thread repeats (M/N), tm/tn = block threads (M/N),
// vn = N vector width, pd = pad-N, x = save-X flag, 3p = multi-pass flag
//                                                        rm rn tm  tn   vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Explicit instantiations of the fp16 fused add + rmsnorm2d + row-dynamic-quant
// forward launcher: 4x64-thread tiles with N repeats 1..8, one-pass pipeline.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm/rn = per-thread repeats (M/N), tm/tn = block threads (M/N),
// vn = N vector width, pd = pad-N, x = save-X flag, 3p = multi-pass flag
//                                                        rm rn tm tn  vn  pd     x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Explicit instantiations of the fp16 fused add + rmsnorm2d + row-dynamic-quant
// forward launcher: 4x64-thread tiles with small vector widths (vn = 1 or 2),
// one-pass pipeline.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm/rn = per-thread repeats (M/N), tm/tn = block threads (M/N),
// vn = N vector width, pd = pad-N, x = save-X flag, 3p = multi-pass flag
//                                                        rm rn tm tn  vn  pd     x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Explicit instantiations of the fp16 fused add + rmsnorm2d + row-dynamic-quant
// forward launcher: 4x64-thread tiles with N repeats that are multiples of 3,
// one-pass pipeline.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm/rn = per-thread repeats (M/N), tm/tn = block threads (M/N),
// vn = N vector width, pd = pad-N, x = save-X flag, 3p = multi-pass flag
//                                                        rm rn  tm tn  vn  pd     x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "add_rmsnorm2d_rdquant_fwd.hpp"
#include <iostream>
#pragma once
// Shorthand aliases used by the explicit-instantiation translation units.
using S = ck_tile::stream_config;
using A = add_rmsnorm2d_rdquant_fwd_args;
// Compile-time tiling configuration for one kernel instance. Bundles the data
// type, per-thread repeat counts, block shape, vectorization width, and the
// boolean feature switches into a single traits type.
// NOTE(review): the last two switches are named kSaveInvRms_/kTwoPass_ here,
// but the launcher below reads Traits_::kSaveX and Traits_::kThreePass from the
// resulting traits struct — confirm the mapping inside
// add_rmsnorm2d_rdquant_fwd_traits_.
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, // pad the N dimension when true
bool kSaveInvRms_, // also emit the auxiliary output when true
bool kTwoPass_> // select the multi-pass pipeline when true
using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_>;
template <typename Traits_>
float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<
typename AddRmsnormRdquantTypeConfig<DataType>::ADataType,
typename AddRmsnormRdquantTypeConfig<DataType>::BDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::GammaDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::ComputeDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::XDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::YScaleDataType,
typename AddRmsnormRdquantTypeConfig<DataType>::QYDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveX,
Traits_::kThreePass>;
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<PipelineProblem>;
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kThreePass, ThreePassPipeline, OnePassPipeline>;
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#!/bin/sh
# Performance sweep for the fused add + rmsnorm2d + row-dynamic-quant example.
# run from top of ck folder
EXE=build/bin/tile_add_rmsnorm2d_rdquant_fwd

# Tiny sanity point (bf16 only, matching the original sweep).
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000

# Same row widths for both precisions; bf16 runs first to preserve the
# original execution order.
for prec in bf16 fp16; do
    for n in 80 128 144 168 184 256 288 344 376 448 512 924 1024 1078 1996 4080; do
        $EXE -m=700 -n=$n -e=1e-12 -v=1 -prec=$prec -repeat=1000
    done
done
\ No newline at end of file
#!/bin/sh
# Smoke test for the fused add + rmsnorm2d + row-dynamic-quant example: run
# assorted (m, n) shapes — including odd sizes and rows whose stride exceeds n —
# for both supported precisions. Each invocation self-validates its output.
# call from top of CK folder
EXE=./build/bin/tile_add_rmsnorm2d_rdquant_fwd
for pr_i in "fp16" "bf16" ; do
# Small and irregular shapes, some with a padded row stride (-stride > -n).
$EXE -prec=$pr_i -m=99 -n=13
$EXE -prec=$pr_i -m=17 -n=16
$EXE -prec=$pr_i -m=1 -n=100
$EXE -prec=$pr_i -m=4 -n=128
$EXE -prec=$pr_i -m=80 -n=127
$EXE -prec=$pr_i -m=22 -n=255 -stride=256
$EXE -prec=$pr_i -m=7 -n=599
$EXE -prec=$pr_i -m=19 -n=512
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
$EXE -prec=$pr_i -m=11 -n=510
$EXE -prec=$pr_i -m=171 -n=676 -stride=818
$EXE -prec=$pr_i -m=91 -n=636
$EXE -prec=$pr_i -m=12 -n=768 -stride=800
$EXE -prec=$pr_i -m=100 -n=766 -stride=812
$EXE -prec=$pr_i -m=31 -n=1024
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
# Wide rows, presumably exercising the multi-pass pipeline for large n —
# NOTE(review): confirm against the instance selection logic.
$EXE -prec=$pr_i -m=8 -n=1501
$EXE -prec=$pr_i -m=3 -n=1826
$EXE -prec=$pr_i -m=5 -n=2040
$EXE -prec=$pr_i -m=7 -n=2734
$EXE -prec=$pr_i -m=1 -n=3182
$EXE -prec=$pr_i -m=9 -n=4096
$EXE -prec=$pr_i -m=3 -n=8192
$EXE -prec=$pr_i -m=1 -n=10547
$EXE -prec=$pr_i -m=3 -n=17134
done
...@@ -7,3 +7,7 @@ add_subdirectory(02_layernorm2d) ...@@ -7,3 +7,7 @@ add_subdirectory(02_layernorm2d)
add_subdirectory(03_gemm) add_subdirectory(03_gemm)
add_subdirectory(04_img2col) add_subdirectory(04_img2col)
add_subdirectory(05_reduce) add_subdirectory(05_reduce)
add_subdirectory(06_permute)
add_subdirectory(09_topk_softmax)
add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant)
...@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
Args... args) Args... args)
{ {
#if CK_TIME_KERNEL #if CK_TIME_KERNEL
#define MEDIAN 1 #define MEDIAN 0
if(stream_config.time_kernel_) if(stream_config.time_kernel_)
{ {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
...@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
#else #else
float total_time = 0; float total_time = 0;
#endif #endif
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for(int i = 0; i < nrepeat; ++i) for(int i = 0; i < nrepeat; ++i)
{ {
if constexpr(!TimePreprocess) if constexpr(!TimePreprocess)
...@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
preprocess(); preprocess();
} }
hipEvent_t start, stop; // hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start)); // hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop)); // hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize()); // hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_)); // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// calculate preprocess time // calculate preprocess time
if constexpr(TimePreprocess) if constexpr(TimePreprocess)
{ {
...@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error(hipGetLastError()); hip_check_error(hipGetLastError());
// end real kernel // end real kernel
hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop)); // hip_check_error(hipEventSynchronize(stop));
float cur_time = 0; // float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop)); // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN // #if MEDIAN
times.insert(cur_time); // times.insert(cur_time);
#else // #else
total_time += cur_time; // total_time += cur_time;
#endif // #endif
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n", printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
static_cast<const void*>(gemm_args.p_a_grid), static_cast<const void*>(gemm_args.p_a_grid),
static_cast<const void*>(gemm_args.p_b_grid)); static_cast<const void*>(gemm_args.p_b_grid));
} }
} }
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
times.insert(cur_time);
#else
total_time += cur_time;
#endif
#if MEDIAN #if MEDIAN
auto mid = times.begin(); auto mid = times.begin();
...@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
return (*mid + *mid_next) / 2; return (*mid + *mid_next) / 2;
} }
#else #else
return total_time / nrepeat; // return total_time / nrepeat;
hipDeviceProp_t deviceProps;
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
return (total_time - preprocess_offset * nrepeat) / nrepeat;
#endif #endif
} }
else else
......
...@@ -352,7 +352,7 @@ struct BlockwiseGemmWMMA ...@@ -352,7 +352,7 @@ struct BlockwiseGemmWMMA
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run( wmma_gemm.template Run<>(
a_thread_vec.template AsType<wmma_input_type_a>(), a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(), b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
...@@ -406,7 +406,7 @@ struct BlockwiseGemmWMMA ...@@ -406,7 +406,7 @@ struct BlockwiseGemmWMMA
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run( wmma_gemm.template Run<>(
a_thread_vec.template AsType<wmma_input_type_a>(), a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(), b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......
...@@ -85,9 +85,9 @@ __global__ void ...@@ -85,9 +85,9 @@ __global__ void
BsPointer p_bs_grid, BsPointer p_bs_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -121,6 +121,19 @@ __global__ void ...@@ -121,6 +121,19 @@ __global__ void
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
{
a_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
{
b_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
{
cde_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(isMultiA || isMultiB) if constexpr(isMultiA || isMultiB)
{ {
AsPointer p_as_grid_grp; AsPointer p_as_grid_grp;
......
...@@ -272,6 +272,26 @@ struct MultiplyMultiply ...@@ -272,6 +272,26 @@ struct MultiplyMultiply
e = ck::type_convert<ck::bhalf_t>(x0_f); e = ck::type_convert<ck::bhalf_t>(x0_f);
} }
template <>
__host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::half_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::bhalf_t>(x0_f);
}
}; };
struct MultiplyAddFastGelu struct MultiplyAddFastGelu
...@@ -385,7 +405,7 @@ struct ScaleAddScaleAddRelu ...@@ -385,7 +405,7 @@ struct ScaleAddScaleAddRelu
const float& d1) const const float& d1) const
{ {
const float x = c * alpha1_ + alpha2_ * d0 + d1; const float x = c * alpha1_ + alpha2_ * d0 + d1;
Relu{}.template operator()<float>(e, x); e = x > 0 ? x : 0;
} }
template <> template <>
...@@ -396,7 +416,7 @@ struct ScaleAddScaleAddRelu ...@@ -396,7 +416,7 @@ struct ScaleAddScaleAddRelu
type_convert<float>(d1); type_convert<float>(d1);
float result = 0; float result = 0;
Relu{}.template operator()<float>(result, x); result = x > 0 ? x : 0;
e = type_convert<half_t>(result); e = type_convert<half_t>(result);
} }
...@@ -409,7 +429,7 @@ struct ScaleAddScaleAddRelu ...@@ -409,7 +429,7 @@ struct ScaleAddScaleAddRelu
type_convert<float>(d1); type_convert<float>(d1);
float result = 0; float result = 0;
Relu{}.template operator()<float>(result, x); result = x > 0 ? x : 0;
e = type_convert<bhalf_t>(result); e = type_convert<bhalf_t>(result);
} }
...@@ -421,7 +441,7 @@ struct ScaleAddScaleAddRelu ...@@ -421,7 +441,7 @@ struct ScaleAddScaleAddRelu
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1; const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
float result = 0; float result = 0;
Relu{}.template operator()<float>(result, x); result = x > 0 ? x : 0;
e = type_convert<int8_t>(result); e = type_convert<int8_t>(result);
} }
......
...@@ -7,11 +7,38 @@ ...@@ -7,11 +7,38 @@
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp" #include "ck/utility/type_convert.hpp"
#include <cassert>
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct UnaryOpBase
{
public:
__host__ __device__ ~UnaryOpBase() = default;
__host__ __device__ constexpr UnaryOpBase() = default;
__host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
__host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
__host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
__host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
__host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
__host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0;
__host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0;
__host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0;
__host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0;
__host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0;
};
struct PassThroughPack2 struct PassThroughPack2
{ {
template <typename Y, typename X> template <typename Y, typename X>
...@@ -25,17 +52,30 @@ struct PassThroughPack2 ...@@ -25,17 +52,30 @@ struct PassThroughPack2
constexpr const static bool is_pack2_invocable = true; constexpr const static bool is_pack2_invocable = true;
}; };
struct PassThrough struct PassThrough final : public UnaryOpBase
{ {
__host__ __device__ constexpr PassThrough() = default;
__host__ __device__ constexpr PassThrough(const PassThrough&) = default;
__host__ __device__ constexpr PassThrough(PassThrough&&) = default;
__host__ __device__ PassThrough& operator=(const PassThrough&) = default;
__host__ __device__ PassThrough& operator=(PassThrough&&) = default;
__host__ __device__ ~PassThrough() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
__host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; }
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; }
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; }
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; }
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; }
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const __host__ __device__ void operator()<float, double>(float& y, const double& x) const
{ {
...@@ -48,36 +88,12 @@ struct PassThrough ...@@ -48,36 +88,12 @@ struct PassThrough
y = type_convert<double>(x); y = type_convert<double>(x);
} }
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
{ {
y = type_convert<half_t>(x); y = type_convert<half_t>(x);
} }
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{ {
...@@ -102,12 +118,6 @@ struct PassThrough ...@@ -102,12 +118,6 @@ struct PassThrough
y = type_convert<float>(x); y = type_convert<float>(x);
} }
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
{ {
...@@ -407,20 +417,45 @@ struct UnarySquare ...@@ -407,20 +417,45 @@ struct UnarySquare
}; };
}; };
struct UnaryAbs struct UnaryAbs final : public UnaryOpBase
{ {
template <typename T> __host__ __device__ constexpr UnaryAbs() = default;
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
__host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
__host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
__host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
__host__ __device__ ~UnaryAbs() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || y = ck::math::abs(x);
is_same<T, half_t>::value || is_same<T, int32_t>::value || }
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
y = ck::math::abs(x); y = ck::math::abs(x);
}; }
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
y = ck::math::abs(x);
}
template <>
__host__ __device__ void operator()(f8_t& y, const f8_t& x) const __host__ __device__ void operator()(f8_t& y, const f8_t& x) const
{ {
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x))); y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
...@@ -439,20 +474,41 @@ struct UnarySqrt ...@@ -439,20 +474,41 @@ struct UnarySqrt
}; };
}; };
struct Relu struct Relu final : public UnaryOpBase
{ {
template <typename T> __host__ __device__ constexpr Relu() = default;
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ constexpr Relu(const Relu&) = default;
__host__ __device__ constexpr Relu(Relu&&) = default;
__host__ __device__ Relu& operator=(const Relu&) = default;
__host__ __device__ Relu& operator=(Relu&&) = default;
__host__ __device__ ~Relu() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0; y = x > 0 ? x : 0;
} }
template <> __host__ __device__ inline void operator()(double& y, const double& x) const final
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const {
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
float x_f32 = ck::type_convert<float>(x); float x_f32 = ck::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0; float y_f32 = x_f32 > 0 ? x_f32 : 0;
...@@ -599,18 +655,52 @@ struct Gelu ...@@ -599,18 +655,52 @@ struct Gelu
} }
}; };
struct Sigmoid struct Sigmoid final : public UnaryOpBase
{ {
template <typename T> __host__ __device__ constexpr Sigmoid() = default;
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
__host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
__host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
__host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
__host__ __device__ ~Sigmoid() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || constexpr float one = type_convert<float>(1);
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || y = one / (one + ck::math::exp(-x));
is_same<T, int32_t>::value, }
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1); __host__ __device__ inline void operator()(double& y, const double& x) const final
y = one / (one + ck::math::exp(-x)); {
}; constexpr double one = type_convert<double>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
constexpr int32_t one = type_convert<int32_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
constexpr int8_t one = type_convert<int8_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
constexpr half_t one = type_convert<half_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
constexpr float one = type_convert<float>(1);
float x_f32 = ck::type_convert<float>(x);
float y_f32 = one / (one + ck::math::exp(x_f32));
y = ck::type_convert<bhalf_t>(y_f32);
}
}; };
struct Silu struct Silu
...@@ -626,18 +716,44 @@ struct Silu ...@@ -626,18 +716,44 @@ struct Silu
}; };
}; };
struct TanH struct TanH final : public UnaryOpBase
{ {
template <typename T> __host__ __device__ constexpr TanH() = default;
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ constexpr TanH(const TanH&) = default;
__host__ __device__ constexpr TanH(TanH&&) = default;
__host__ __device__ TanH& operator=(const TanH&) = default;
__host__ __device__ TanH& operator=(TanH&&) = default;
__host__ __device__ ~TanH() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || y = ck::math::tanh(x);
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || }
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
y = ck::math::tanh(x); y = ck::math::tanh(x);
}; }
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
y = ck::math::tanh(x);
}
}; };
struct ACos struct ACos
...@@ -878,138 +994,418 @@ struct Rcp ...@@ -878,138 +994,418 @@ struct Rcp
}; };
}; };
struct Swish struct Swish final : public UnaryOpBase
{ {
Swish(float beta = 1.0f) : beta_(beta) {} __host__ __device__ constexpr Swish(const Swish&) = default;
__host__ __device__ constexpr Swish(Swish&&) = default;
__host__ __device__ ~Swish() = default;
__host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
__host__ __device__ float get_beta() const { return beta_; }
const float beta_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<float>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<double>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<int32_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<int8_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<half_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<bhalf_t>(x / (1.f + ck::math::exp(bx)));
}
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(is_same<X, float>::value || is_same<X, double>::value || static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, ck::half_t>::value, is_same<X, half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value || static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value, is_same<Y, half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx))); y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
}; }
const float beta_;
}; };
struct SoftRelu struct SoftRelu final : public UnaryOpBase
{ {
SoftRelu(float alpha = 1.f) : alpha_(alpha){}; __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
__host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
__host__ __device__ ~SoftRelu() = default;
template <typename T> __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || float casted_alpha = type_convert<float>(alpha_);
is_same<T, half_t>::value || is_same<T, int32_t>::value || constexpr float one = type_convert<float>(1);
is_same<T, int8_t>::value, y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
"Data type is not supported by this operation!"); }
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1); __host__ __device__ inline void operator()(double& y, const double& x) const final
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; {
double casted_alpha = type_convert<double>(alpha_);
constexpr double one = type_convert<double>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
constexpr int32_t one = type_convert<int32_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
constexpr int8_t one = type_convert<int8_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
constexpr half_t one = type_convert<half_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
constexpr bhalf_t one = type_convert<bhalf_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
} }
const float alpha_;
}; };
struct Power struct Power final : public UnaryOpBase
{ {
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) __host__ __device__ constexpr Power(const Power&) = default;
: alpha_(alpha), beta_(beta), gamma_(gamma){}; __host__ __device__ constexpr Power(Power&&) = default;
__host__ __device__ ~Power() = default;
template <typename T> __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
__host__ __device__ void operator()(T& y, const T& x) const : alpha_(alpha), beta_(beta), gamma_(gamma)
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
} }
__host__ __device__ float get_alpha() const { return alpha_; }
__host__ __device__ float get_beta() const { return beta_; }
__host__ __device__ float get_gamma() const { return gamma_; }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
const float gamma_; const float gamma_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
float casted_beta = type_convert<float>(beta_);
float casted_gamma = type_convert<float>(gamma_);
float shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
double casted_beta = type_convert<double>(beta_);
double casted_gamma = type_convert<double>(gamma_);
double shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
int32_t casted_beta = type_convert<int32_t>(beta_);
int32_t casted_gamma = type_convert<int32_t>(gamma_);
int32_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
int8_t casted_beta = type_convert<int8_t>(beta_);
int8_t casted_gamma = type_convert<int8_t>(gamma_);
int8_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
half_t casted_beta = type_convert<half_t>(beta_);
half_t casted_gamma = type_convert<half_t>(gamma_);
half_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
bhalf_t casted_beta = type_convert<bhalf_t>(beta_);
bhalf_t casted_gamma = type_convert<bhalf_t>(gamma_);
bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
}; };
struct ClippedRelu struct ClippedRelu final : public UnaryOpBase
{ {
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
__host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
__host__ __device__ ~ClippedRelu() = default;
template <typename T> __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
__host__ __device__ void operator()(T& y, const T& x) const : alpha_(alpha), beta_(beta)
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
} }
__host__ __device__ float get_alpha() const { return alpha_; }
__host__ __device__ float get_beta() const { return beta_; }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
float casted_beta = type_convert<float>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
double casted_beta = type_convert<double>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
int32_t casted_beta = type_convert<int32_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
int8_t casted_beta = type_convert<int8_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
half_t casted_beta = type_convert<half_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
bhalf_t casted_beta = type_convert<bhalf_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
}; };
struct LeakyRelu struct LeakyRelu final : public UnaryOpBase
{ {
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default;
__host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default;
__host__ __device__ ~LeakyRelu() = default;
template <typename T> __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y,
[[maybe_unused]] const bhalf_t& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
} }
const float alpha_;
}; };
struct Elu struct Elu final : public UnaryOpBase
{ {
Elu(float alpha = 1.f) : alpha_(alpha){}; __host__ __device__ constexpr Elu(const Elu&) = default;
__host__ __device__ constexpr Elu(Elu&&) = default;
__host__ __device__ ~Elu() = default;
template <typename T> __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || float casted_alpha = type_convert<float>(alpha_);
is_same<T, half_t>::value || is_same<T, int32_t>::value || y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
is_same<T, int8_t>::value, }
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); __host__ __device__ inline void operator()(double& y, const double& x) const final
y = x > 0 ? x : casted_alpha * ck::math::expm1(x); {
double casted_alpha = type_convert<double>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
} }
const float alpha_;
}; };
struct Logistic struct Logistic final : public UnaryOpBase
{ {
Logistic(float alpha = 1.f) : alpha_(alpha){}; __host__ __device__ constexpr Logistic(const Logistic&) = default;
__host__ __device__ constexpr Logistic(Logistic&&) = default;
__host__ __device__ ~Logistic() = default;
template <typename T> __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || float casted_alpha = type_convert<float>(alpha_);
is_same<T, half_t>::value || is_same<T, int32_t>::value || constexpr float one = type_convert<float>(1);
is_same<T, int8_t>::value, y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
"Data type is not supported by this operation!"); }
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1); __host__ __device__ inline void operator()(double& y, const double& x) const final
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); {
double casted_alpha = type_convert<double>(alpha_);
constexpr double one = type_convert<double>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
constexpr int32_t one = type_convert<int32_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
constexpr int8_t one = type_convert<int8_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
constexpr half_t one = type_convert<half_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
constexpr bhalf_t one = type_convert<bhalf_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
} }
const float alpha_;
}; };
struct ConvInvscale struct ConvInvscale
...@@ -1074,7 +1470,7 @@ struct ConvScaleRelu ...@@ -1074,7 +1470,7 @@ struct ConvScaleRelu
__host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
{ {
float x; float x;
Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_); Relu{}(x, c * scale_in_ * scale_wei_);
e = type_convert<f8_t>(x * scale_out_); e = type_convert<f8_t>(x * scale_out_);
}; };
...@@ -1153,6 +1549,255 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N> ...@@ -1153,6 +1549,255 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
}; };
// Runtime-polymorphic wrapper over the unary element-wise operators.
// On the host, only the requested operator kind (unary_op_type_) and its float
// parameters (alpha/beta/gamma) are recorded; on the device, the concrete
// operator is materialized by InitUnaryOpPtrOnDevice() and dispatched through
// the UnaryOpBase virtual call.
struct DynamicUnaryOp
{
    DynamicUnaryOp& operator=(const DynamicUnaryOp& other)
    {
        if(this != &other)
        {
            // NOTE(review): unary_op_ptr_ is a raw owning pointer on the device
            // side; copying it as-is assumes it is still nullptr (i.e. the copy
            // happens on the host, before InitUnaryOpPtrOnDevice) — confirm.
            unary_op_ptr_  = other.unary_op_ptr_;
            unary_op_type_ = other.unary_op_type_;
            // Fix: the parameters must travel with the op type, otherwise
            // parameterized ops (Swish, SoftRelu, Power, ...) lose their state
            // on assignment (the copy constructor already copies them).
            alpha = other.alpha;
            beta  = other.beta;
            gamma = other.gamma;
        }
        return *this;
    }

    __host__ __device__ DynamicUnaryOp() = delete;

    __host__ __device__ DynamicUnaryOp(const Swish& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const Swish&& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; }

    __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; }

    __host__ __device__ DynamicUnaryOp(const PassThrough&)
    {
        unary_op_type_ = UnaryOpType::PassThrough;
    }

    __host__ __device__ DynamicUnaryOp(const PassThrough&&)
    {
        unary_op_type_ = UnaryOpType::PassThrough;
    }

    __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
    {
        unary_op_type_ = UnaryOpType::Logistic;
        alpha          = logistic.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
    {
        unary_op_type_ = UnaryOpType::Logistic;
        alpha          = logistic.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; }

    __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; }

    __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; }

    __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; }

    __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
    {
        unary_op_type_ = UnaryOpType::SoftRelu;
        alpha          = softrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
    {
        unary_op_type_ = UnaryOpType::SoftRelu;
        alpha          = softrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; }

    __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; }

    __host__ __device__ DynamicUnaryOp(const Power& pow)
    {
        unary_op_type_ = UnaryOpType::Power;
        alpha          = pow.get_alpha();
        beta           = pow.get_beta();
        gamma          = pow.get_gamma();
    }

    __host__ __device__ DynamicUnaryOp(const Power&& pow)
    {
        unary_op_type_ = UnaryOpType::Power;
        alpha          = pow.get_alpha();
        beta           = pow.get_beta();
        gamma          = pow.get_gamma();
    }

    __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
    {
        unary_op_type_ = UnaryOpType::ClippedRelu;
        alpha          = clippedrelu.get_alpha();
        beta           = clippedrelu.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
    {
        unary_op_type_ = UnaryOpType::ClippedRelu;
        alpha          = clippedrelu.get_alpha();
        beta           = clippedrelu.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
    {
        unary_op_type_ = UnaryOpType::LeakyRelu;
        alpha          = leakyrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
    {
        unary_op_type_ = UnaryOpType::LeakyRelu;
        alpha          = leakyrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Elu& elu)
    {
        unary_op_type_ = UnaryOpType::Elu;
        alpha          = elu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Elu&& elu)
    {
        unary_op_type_ = UnaryOpType::Elu;
        alpha          = elu.get_alpha();
    }

    // NOTE(review): this copies the raw owning pointer; if both copies later run
    // the deleting destructor on the device, that is a double delete. The
    // intended flow appears to be host-copy (pointer still nullptr) followed by
    // a single device-side init — confirm against callers.
    __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op)
        : unary_op_type_(dynamic_op.unary_op_type_),
          unary_op_ptr_(dynamic_op.unary_op_ptr_),
          alpha(dynamic_op.alpha),
          beta(dynamic_op.beta),
          gamma(dynamic_op.gamma)
    {
    }

    // Deletes the device-side operator instance created by
    // InitUnaryOpPtrOnDevice(). Deleting a nullptr is a no-op, so destroying a
    // never-initialized (host-side) object is safe.
    __host__ __device__ ~DynamicUnaryOp()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
        case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
        case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
        case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
        case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
        case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
        case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
        case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
        case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
        default: break;
        }
    }

    // Materialize the concrete operator on the device, forwarding the stored
    // parameters so the virtual dispatch path sees the configured op.
    __device__ void InitUnaryOpPtrOnDevice()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break;
        case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break;
        case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break;
        case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break;
        case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break;
        case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break;
        case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break;
        case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break;
        case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break;
        case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break;
        case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break;
        case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break;
        default: unary_op_ptr_ = nullptr; break;
        }
    }

    template <typename Y, typename X>
    __device__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        unary_op_ptr_->operator()(y, x);
    }

    // Host-side dispatch. Fix: construct each operator with the stored
    // parameters instead of default-constructing it; previously the host path
    // ignored alpha/beta/gamma and disagreed with the device path for any
    // non-default configuration.
    template <typename Y, typename X>
    __host__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): Swish{beta}.operator()(y, x); break;
        case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break;
        case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break;
        case(UnaryOpType::Logistic): Logistic{alpha}.operator()(y, x); break;
        case(UnaryOpType::TanH): TanH{}.operator()(y, x); break;
        case(UnaryOpType::Relu): Relu{}.operator()(y, x); break;
        case(UnaryOpType::SoftRelu): SoftRelu{alpha}.operator()(y, x); break;
        case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break;
        case(UnaryOpType::Power): Power{alpha, beta, gamma}.operator()(y, x); break;
        case(UnaryOpType::ClippedRelu): ClippedRelu{alpha, beta}.operator()(y, x); break;
        case(UnaryOpType::LeakyRelu): LeakyRelu{alpha}.operator()(y, x); break;
        case(UnaryOpType::Elu): Elu{alpha}.operator()(y, x); break;
        default: break;
        }
    }

    // Compile-time guard: dispatch requires Y == X and one of the arithmetic
    // types every concrete operator overloads on.
    template <typename X, typename Y>
    __device__ __host__ constexpr void isSupported() const
    {
        static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type");

        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
                          is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
                          is_same<X, int32_t>::value || is_same<X, int8_t>::value,
                      "Data type is not supported by this operation!");
    }

    private:
    enum class UnaryOpType
    {
        Swish,
        Sigmoid,
        PassThrough,
        Logistic,
        TanH,
        Relu,
        SoftRelu,
        UnaryAbs,
        Power,
        ClippedRelu,
        LeakyRelu,
        Elu
    };

    public:
    UnaryOpType unary_op_type_;
    UnaryOpBase* unary_op_ptr_ = nullptr; // device-side owning pointer; see dtor
    float alpha;                          // op-specific parameter (see each ctor)
    float beta;
    float gamma;
};
#pragma clang diagnostic pop
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16> ...@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<int32x4_t>()(Number<0>{}) = reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a), __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b), bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}], reg_c.template AsType<int32x4_t>()[Number<0>{}],
0, 0,
0, 0,
0); 0);
} }
}; };
......
...@@ -1803,4 +1803,13 @@ struct NumericUtils<bf8_t> ...@@ -1803,4 +1803,13 @@ struct NumericUtils<bf8_t>
static constexpr int bias = 16; // negative zero nan mode static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode // static constexpr int bias = 15; // ieee mode
}; };
// Bit-layout traits for bfloat16: 1 sign bit, 8 exponent bits, 7 mantissa bits.
template <>
struct NumericUtils<bhalf_t>
{
    static constexpr int exp  = 8; // exponent field width in bits
    static constexpr int mant = 7; // mantissa (fraction) field width in bits
    static constexpr int bias = 128; // negative zero nan mode
    // static constexpr int bias = 127; // ieee mode
};
} // namespace ck } // namespace ck
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core/algorithm/cluster_descriptor.hpp" #include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/arch.hpp"
...@@ -24,6 +25,7 @@ ...@@ -24,6 +25,7 @@
#include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/math.hpp"
...@@ -49,14 +51,17 @@ ...@@ -49,14 +51,17 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/type_traits.hpp"
......
...@@ -23,6 +23,7 @@ enum struct coord_transform_enum ...@@ -23,6 +23,7 @@ enum struct coord_transform_enum
replicate, replicate,
xor_t, xor_t,
offset, offset,
indexing,
}; };
template <index_t NDimLow, index_t NDimUp> template <index_t NDimLow, index_t NDimUp>
...@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1> ...@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
} }
}; };
// 1-D -> 1-D gather-style coordinate transform: the mapping from upper index
// to lower index is delegated entirely to IndexingAdaptor (e.g. a cached
// index array), while this struct carries the upper length and plugs into the
// generic coord_transform machinery.
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
    static constexpr index_t NDimUp = 1;
    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(UpLength{}));

    UpLengths up_lengths_;
    IndexingAdaptor iadaptor_;

    CK_TILE_HOST_DEVICE constexpr indexing() = default;

    CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
                                           const IndexingAdaptor& iadaptor)
        : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::indexing;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    // Delegates the index mapping to the adaptor.
    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        iadaptor_.calculate_lower_index(idx_low, idx_up);
    }

    // Incremental variant used while a coordinate is being moved.
    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx& idx_up) const
    {
        // TODO: nothing changed here
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
                          LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               IndexingAdaptor::is_known_at_compile_time();
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        // Fix: label was "embed{", a copy/paste from the embed transform.
        printf("indexing{");

        //
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");

        printf("}");
    }
};
//******************************************************************************************************* //*******************************************************************************************************
template <typename LowLength> template <typename LowLength>
...@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le ...@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
} }
} // namespace ck_tile } // namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
namespace ck_tile {
/// Build an indexing transform that gathers through an explicit index array.
/// The one-shot cached adaptor is chosen because it is the simplest strategy.
template <typename UpLength, typename Indices>
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
                                                           const Indices& indices)
{
    using adaptor_type = indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>;
    return indexing<UpLength, adaptor_type>{up_lengths, adaptor_type{indices}};
}
/// Build an indexing transform from a caller-provided indexing adaptor,
/// for strategies other than the default one-shot cached adaptor.
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE constexpr auto
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
{
    using transform_type = indexing<UpLength, IndexingAdaptor>;
    return transform_type{up_lengths, iadaptor};
}
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment