Commit e0041ad8 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/drop_cshuffle

parents 3239201e ac9e01e2
......@@ -6,6 +6,7 @@
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
......@@ -66,8 +67,26 @@ struct ReferenceGemm : public device::BaseOperator
ADataType v_a;
BDataType v_b;
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceGemmBiasActivation : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
Tensor<CDataType>& c_m_n,
const Tensor<CDataType>& c0_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
c_m_n_{c_m_n},
c0_n_{c0_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
Tensor<CDataType>& c_m_n_;
const Tensor<CDataType>& c0_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceGemmBiasActivation::Argument;
float Run(const Argument& arg)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
float v_acc = 0;
for(int k = 0; k < K; ++k)
{
float v_a;
float v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b;
}
float v_c;
arg.c_element_op_(v_c, v_acc, static_cast<float>(arg.c0_n_(n)));
arg.c_m_n_(m, n) = v_c;
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
Tensor<CDataType>& c_m_n,
const Tensor<CDataType>& c0_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemmBiasActivation"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
Tensor<CDataType>& c_m_n,
const Tensor<CDataType>& c0_n,
const Tensor<CDataType>& c1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
c_m_n_{c_m_n},
c0_n_{c0_n},
c1_m_n_{c1_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
Tensor<CDataType>& c_m_n_;
const Tensor<CDataType>& c0_n_;
const Tensor<CDataType>& c1_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceGemmBiasActivationAdd::Argument;
float Run(const Argument& arg)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
float v_acc = 0;
for(int k = 0; k < K; ++k)
{
float v_a;
float v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b;
}
float v_c;
arg.c_element_op_(v_c,
v_acc,
static_cast<float>(arg.c0_n_(n)),
static_cast<float>(arg.c1_m_n_(m, n)));
arg.c_m_n_(m, n) = v_c;
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
Tensor<CDataType>& c_m_n,
const Tensor<CDataType>& c0_n,
const Tensor<CDataType>& c1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{
a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemmBiasActivationAdd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -90,10 +90,13 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m)
{
AccDataType divisor =
static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n)
{
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
arg.acc_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <index_t InOutRank,
index_t WindowRank,
typename InDataType,
typename OutDataType,
typename ComputeDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
bool OutputIndex>
struct ReferencePoolingFwd : public device::BaseOperator
{
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& /*in_right_pads*/)
: in_(in),
out_(out),
out_indices_(out_indices),
window_spatial_lengths_(window_spatial_lengths),
window_strides_(window_strides),
in_left_pads_(in_left_pads),
reduceLength_(1)
{
static_for<0, WindowRank, 1>{}(
[&](auto I) { reduceLength_ *= window_spatial_lengths[I]; });
}
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
Tensor<IndexDataType>& out_indices_;
const std::vector<ck::index_t>& window_spatial_lengths_;
const std::vector<ck::index_t>& window_strides_;
const std::vector<ck::index_t>& in_left_pads_;
int reduceLength_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float RunPooling3dFwd(const Argument& arg)
{
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
arg.reduceLength_);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation = ck::detail::
AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_ncdhw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3],
arg.out_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
ComputeDataType,
IndexDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_ncdhw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3],
arg.out_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
};
return 0;
}
float RunPooling2dFwd(const Argument& arg)
{
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
arg.reduceLength_);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation = ck::detail::
AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_nchw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
ComputeDataType,
IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
arg.out_indices_(n, c, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_nchw,
arg.out_.mDesc.GetLengths()[0],
arg.out_.mDesc.GetLengths()[1],
arg.out_.mDesc.GetLengths()[2],
arg.out_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
};
return 0;
}
float Run(const Argument& arg)
{
// TODO - support generic pooling
if constexpr(InOutRank == 5 && WindowRank == 3)
return RunPooling3dFwd(arg);
else if constexpr(InOutRank == 4 && WindowRank == 2)
return RunPooling2dFwd(arg);
else
throw std::runtime_error("Only support pooling3d or pooling2d so far");
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& in_right_pads)
{
return Argument{in,
out,
out_indices,
window_spatial_lengths,
window_strides,
in_left_pads,
in_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferencePoolingFwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/ck.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool OutputIndex>
struct ReferenceReduce : public device::DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>
{
using IndexDataType = int32_t;
static constexpr int NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumSrcDim = Rank;
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
double alpha,
double beta,
const InDataType* in_host,
OutDataType* out_host,
IndexDataType* out_index_host,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: reduceDims_(reduceDims),
outLengths_(outLengths),
outStrides_(outStrides),
in_host_(in_host),
out_host_(out_host),
out_index_host_(out_index_host),
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
if constexpr(NumInvariantDim > 0)
{
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = inLengths[dim];
i++;
};
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = inLengths[dim];
};
if constexpr(NumInvariantDim > 0)
{
// check invariant_lengths_ and outLengths
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != outLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
}
if constexpr(NumInvariantDim > 0)
{
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
in_invariant_strides_[i] = inStrides[dim];
i++;
};
};
for(int j = 0, i = 0; j < NumReduceDim; j++)
{
int dim = reduceDims_[j];
in_reduce_strides_[i] = inStrides[dim];
i++;
};
if constexpr(NumInvariantDim > 0)
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumReduceDim>(reduce_lengths_);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
};
const std::array<int, NumReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumReduceDim> reduce_lengths_;
const std::array<index_t, NumDstDim> outLengths_;
const std::array<index_t, NumDstDim> outStrides_;
std::array<index_t, NumInvariantDim> in_invariant_strides_;
std::array<index_t, NumReduceDim> in_reduce_strides_;
const InDataType* in_host_;
OutDataType* out_host_;
IndexDataType* out_index_host_;
const InElementwiseOperation in_elementwise_op_;
const AccElementwiseOperation acc_elementwise_op_;
AccDataType alpha_;
AccDataType beta_;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumReduceDim>> reduce_index_set_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ignore = stream_config;
using ck::float_equal_one;
using ck::float_equal_zero;
using ck::type_convert;
using ck::host_common::get_index_set;
using ck::host_common::get_offset_from_index;
if constexpr(OutputIndex)
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(std::size_t i = 0; i < arg.reduce_index_set_.size(); i++)
{
auto in_offset = get_offset_from_index<NumReduceDim>(
arg.in_reduce_strides_, arg.reduce_index_set_[i]);
auto currVal = type_convert<AccDataType>(arg.in_host_[in_offset]);
arg.in_elementwise_op_(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
arg.acc_elementwise_op_(accuVal, accuVal);
if(!float_equal_one{}(arg.alpha_))
accuVal *= type_convert<AccDataType>(arg.alpha_);
if(!float_equal_zero{}(arg.beta_))
accuVal += type_convert<AccDataType>(arg.out_host_[0]) *
type_convert<AccDataType>(arg.beta_);
arg.out_host_[0] = type_convert<OutDataType>(accuVal);
arg.out_index_host_[0] = accuIndex;
}
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal =
ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
auto in_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.in_invariant_strides_, invariant_index);
for(std::size_t i = 0; i < arg.reduce_index_set_.size(); i++)
{
auto in_reduce_offset = get_offset_from_index<NumReduceDim>(
arg.in_reduce_strides_, arg.reduce_index_set_[i]);
auto currVal = type_convert<AccDataType>(
arg.in_host_[in_invariant_offset + in_reduce_offset]);
arg.in_elementwise_op_(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
};
arg.acc_elementwise_op_(accuVal, accuVal);
if(!float_equal_one{}(arg.alpha_))
accuVal *= type_convert<AccDataType>(arg.alpha_);
auto dst_offset = get_offset_from_index<NumInvariantDim>(arg.outStrides_,
invariant_index);
if(!float_equal_zero{}(arg.beta_))
accuVal += type_convert<AccDataType>(arg.out_host_[dst_offset]) *
type_convert<AccDataType>(arg.beta_);
arg.out_host_[dst_offset] = type_convert<OutDataType>(accuVal);
arg.out_index_host_[dst_offset] = accuIndex;
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t i_begin = it * work_per_thread;
std::size_t i_end =
std::min((it + 1) * work_per_thread, arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t i = i_begin; i < i_end; i++)
{
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
threads[it] = joinable_thread(f);
}
};
}
else
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
if constexpr(NumInvariantDim == 0)
{
AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(const auto& reduce_index : arg.reduce_index_set_)
{
auto in_offset = get_offset_from_index<NumReduceDim>(arg.in_reduce_strides_,
reduce_index);
auto currVal = type_convert<AccDataType>(arg.in_host_[in_offset]);
arg.in_elementwise_op_(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
arg.acc_elementwise_op_(accuVal, accuVal);
if(!float_equal_one{}(arg.alpha_))
accuVal *= type_convert<AccDataType>(arg.alpha_);
if(!float_equal_zero{}(arg.beta_))
accuVal += type_convert<AccDataType>(arg.out_host_[0]) *
type_convert<AccDataType>(arg.beta_);
arg.out_host_[0] = type_convert<OutDataType>(accuVal);
}
else
{
auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal =
ReduceOperation::template GetIdentityValue<AccDataType>();
auto in_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.in_invariant_strides_, invariant_index);
for(const auto& reduce_index : arg.reduce_index_set_)
{
auto in_reduce_offset = get_offset_from_index<NumReduceDim>(
arg.in_reduce_strides_, reduce_index);
auto currVal = type_convert<AccDataType>(
arg.in_host_[in_invariant_offset + in_reduce_offset]);
arg.in_elementwise_op_(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
};
arg.acc_elementwise_op_(accuVal, accuVal);
if(!float_equal_one{}(arg.alpha_))
accuVal *= type_convert<AccDataType>(arg.alpha_);
auto dst_offset = get_offset_from_index<NumInvariantDim>(arg.outStrides_,
invariant_index);
if(!float_equal_zero{}(arg.beta_))
accuVal += type_convert<AccDataType>(arg.out_host_[dst_offset]) *
type_convert<AccDataType>(arg.beta_);
arg.out_host_[dst_offset] = type_convert<OutDataType>(accuVal);
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t i_begin = it * work_per_thread;
std::size_t i_end =
std::min((it + 1) * work_per_thread, arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t i = i_begin; i < i_end; i++)
{
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
threads[it] = joinable_thread(f);
}
};
};
return (0.0f);
};
float Run(const device::BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
ignore = p_arg;
return true;
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
double alpha,
double beta,
const void* in_host,
const void* in_index_host,
void* out_host,
void* out_index_host,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
ignore = in_index_host;
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_host),
static_cast<OutDataType*>(out_host),
static_cast<IndexDataType*>(out_index_host),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Reference_Reduce<" << std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -24,11 +24,14 @@ struct ReferenceSoftmax : public device::BaseOperator
{
Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
double alpha,
double beta,
const std::vector<index_t> sm_reduce_dims)
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
: in_(in), out_(out), sm_reduce_dims_(sm_reduce_dims)
{
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<AccDataType>(beta);
// std::cout << "debug: scalar dims: ";
for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
{
......@@ -143,8 +146,8 @@ struct ReferenceSoftmax : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
double alpha,
double beta,
const std::vector<index_t> sm_reduce_dims)
{
return Argument{in, out, alpha, beta, sm_reduce_dims};
......
......@@ -26,6 +26,7 @@ using Empty_Tuple = ck::Tuple<>;
using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F64_Tuple = ck::Tuple<F64>;
using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>;
......@@ -85,11 +86,17 @@ using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......@@ -98,6 +105,10 @@ template <typename Activation>
using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Add_Mul_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
......@@ -105,6 +116,10 @@ template <typename Activation>
using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Mul2_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory;
......
......@@ -3,8 +3,8 @@
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
......
......@@ -3,8 +3,8 @@
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance(
std::vector<std::unique_ptr<
DeviceBatchedContractionMultipleD<1,
2,
3,
1,
F16,
F16,
F16_Tuple,
F16,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Add>>>& instances);
// Contraction + add
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedContractionMultipleD<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Add>>
{
using DeviceOp =
DeviceBatchedContractionMultipleD<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Add>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, ck::half_t> && is_same_v<BDataType, ck::half_t> &&
is_same_v<DDataType, ck::half_t> && is_same_v<EDataType, ck::half_t>)
{
if constexpr(NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1)
{
add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpec>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpec>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -3,8 +3,8 @@
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
......
......@@ -3,8 +3,8 @@
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_infer_rank_4_f16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F16, F32, F32, F16, F16>,
ck::Tuple<F16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP32
void add_device_batchnorm_infer_rank_4_f32_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F32, F32, F32, F32, F32>,
ck::Tuple<F32>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// BF16
void add_device_batchnorm_infer_rank_4_bf16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<BF16, F32, F32, BF16, BF16>,
ck::Tuple<BF16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP64
void add_device_batchnorm_infer_rank_4_f64_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F64, F64, F64, F64, F64>,
ck::Tuple<F64>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
template <typename XDataType,
typename YDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
index_t Rank>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<XDataType, MeanVarDataType, MeanVarDataType, ScaleDataType, BiasDataType>,
ck::Tuple<YDataType>,
ck::tensor_operation::element_wise::NormalizeInInfer,
Rank>>
{
using DeviceOp = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<XDataType, MeanVarDataType, MeanVarDataType, ScaleDataType, BiasDataType>,
ck::Tuple<YDataType>,
ck::tensor_operation::element_wise::NormalizeInInfer,
Rank>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -3,10 +3,8 @@
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
......@@ -19,6 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// float
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
......@@ -67,6 +66,55 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
PassThrough,
Bilinear>>>& instances);
// double
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear>>>& instances);
// Contraction + Bilinear
template <index_t NumDimM,
index_t NumDimN,
......@@ -118,6 +166,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
}
}
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<DDataType, double> && is_same_v<EDataType, double>)
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
......
......@@ -3,10 +3,8 @@
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
......@@ -19,6 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// float
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
......@@ -67,6 +66,55 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
PassThrough,
Scale>>>& instances);
// double
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F64,
F64,
Empty_Tuple,
F64,
PassThrough,
PassThrough,
Scale>>>& instances);
// Contraction + Scale
template <index_t NumDimM,
index_t NumDimN,
......@@ -117,6 +165,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
}
}
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<EDataType, double>)
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
......
......@@ -3,8 +3,8 @@
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
......
......@@ -3,8 +3,8 @@
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
......
......@@ -3,11 +3,10 @@
#pragma once
#include <cstdlib>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -18,11 +17,8 @@ namespace device {
namespace instance {
using Normalize = ck::tensor_operation::element_wise::Normalize;
using DeviceNormalizeFromMeanMeanSquarePtr = ck::tensor_operation::device::DeviceElementwiseBasePtr<
Tuple<half_t, float, float, half_t, half_t>,
Tuple<half_t>,
Normalize,
2>;
using DeviceNormalizeFromMeanMeanSquarePtr = ck::tensor_operation::device::
DeviceElementwisePtr<Tuple<half_t, float, float, half_t, half_t>, Tuple<half_t>, Normalize, 2>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment