Commit b097be17 authored by root's avatar root
Browse files

merge changes for upstream/latest update

parents 8a891bbd a49115b9
......@@ -18,12 +18,12 @@ struct GeneratorTensor_0
template <typename T>
struct GeneratorTensor_1
{
int value = 1;
T value = 1;
template <typename... Is>
T operator()(Is...)
{
return ck::type_convert<T>(value);
return value;
}
};
......
......@@ -106,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
arg.in_element_op_(v_acc, v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncw,
......
......@@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
for(int k = 0; k < K; ++k)
{
arg.a_element_op_(a, arg.a_m_k_(m, k));
arg.b_element_op_(b, arg.b_k_n_(k, n));
arg.a_element_op_(a, ck::type_convert<AccDataType>(arg.a_m_k_(m, k)));
arg.b_element_op_(b, ck::type_convert<AccDataType>(arg.b_k_n_(k, n)));
acc += a * b;
}
......
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename InDataType, typename OutDataType, typename AccDataType>
struct ReferenceSoftmax : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
const index_t rank,
const std::vector<index_t> sm_reduce_dims)
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
{
// std::cout << "debug: scalar dims: ";
for(int i = 0; i < rank; i++)
{
if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
sm_reduce_dims.end())
{
sm_scalar_dims_.push_back(i);
// std::cout << i << ", ";
}
}
// std::cout << std::endl;
}
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
AccDataType alpha_;
AccDataType beta_;
index_t rank_;
std::vector<index_t> sm_reduce_dims_;
std::vector<index_t> sm_scalar_dims_; // dim after internal max/sum reduction
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
std::vector<size_t> scalar_lengths;
for(index_t dim : arg.sm_scalar_dims_)
{
scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
}
Tensor<AccDataType> reduce_max(scalar_lengths);
reduce_max.GenerateTensorValue(
GeneratorTensor_1<AccDataType>{std::numeric_limits<AccDataType>::lowest()});
Tensor<AccDataType> reduce_sum(scalar_lengths);
reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
auto to_sm_scalar_idx = [&](auto idx) {
std::vector<size_t> sm_scalar_idx;
for(index_t dim : arg.sm_scalar_dims_)
{
sm_scalar_idx.push_back(idx[dim]);
}
return sm_scalar_idx;
};
arg.in_.ForEach([&](auto& self, auto idx) {
reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)),
static_cast<AccDataType>(self(idx)));
});
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
// std::endl;
Tensor<AccDataType> in_stable(arg.in_.mDesc);
in_stable.ForEach([&](auto& self, auto idx) {
// numerator = exp(x - max(x))
self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
reduce_max(to_sm_scalar_idx(idx)));
});
// LogRangeAsType<float>(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl;
in_stable.ForEach([&](auto& self, auto idx) {
// denominator = sum(exp(x - max(x)))
reduce_sum(to_sm_scalar_idx(idx)) += self(idx);
});
// LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",") <<
// std::endl;
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
arg.beta_ * self(idx);
});
// LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
// reduction along reduce dims
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
// std::endl; LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",")
// << std::endl;
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
const index_t rank,
const std::vector<index_t> sm_reduce_dims)
{
return Argument{in, out, alpha, beta, rank, sm_reduce_dims};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceSoftmax"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP
#define CK_DEVICE_OPERATION_INSTANCE_HPP
#pragma once
#include <stdlib.h>
#include <vector>
namespace ck {
namespace tensor_operation {
......@@ -23,4 +22,3 @@ void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -61,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
>;
#endif
template <typename AccDataType, ReduceTensorOp ReduceOpId>
template <ReduceTensorOp ReduceOpId>
using deviceReduceBlockWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
template <typename InDataType,
typename AccDataType,
......@@ -75,14 +75,13 @@ template <typename InDataType,
bool PropagateNan,
bool UseIndex>
void add_device_reduce_instance_blockwise(
std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>>& device_op_instances)
{
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
......@@ -137,7 +136,7 @@ void add_device_reduce_instance_blockwise(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......@@ -150,21 +149,17 @@ void add_device_reduce_instance_blockwise(
Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
......@@ -61,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
>;
#endif
template <typename AccDataType, ReduceTensorOp ReduceOperation>
using deviceReduceMultiBlockAtomicAddPtrType =
DeviceReducePtr<typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
AccElementwiseOperation>;
template <ReduceTensorOp ReduceOperation>
using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr<
typename reduce_unary_operator<ReduceOperation, true, true>::InElementwiseOperation,
typename reduce_unary_operator<ReduceOperation, true, true>::AccElementwiseOperation>;
template <typename InDataType,
typename AccDataType,
......@@ -77,15 +75,13 @@ template <typename InDataType,
bool PropagateNan,
bool UseIndex>
void add_device_reduce_instance_multiblock_atomic_add(
std::vector<deviceReduceMultiBlockAtomicAddPtrType<AccDataType, ReduceOpId>>&
device_op_instances)
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>>& device_op_instances)
{
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
......@@ -158,8 +154,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
device_op_instances)
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......@@ -172,21 +167,17 @@ void add_device_reduce_instance_multiblock_atomic_add(
Rank, \
NumReduceDim)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
......@@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple<
>;
#endif
template <typename AccDataType, ReduceTensorOp ReduceOpId>
template <ReduceTensorOp ReduceOpId>
using deviceReduceThreadWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
template <typename InDataType,
typename AccDataType,
......@@ -61,14 +61,13 @@ template <typename InDataType,
bool PropagateNan,
bool UseIndex>
void add_device_reduce_instance_threadwise(
std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>>& device_op_instances)
{
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
......@@ -114,7 +113,7 @@ void add_device_reduce_instance_threadwise(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_THREADWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......@@ -127,21 +126,17 @@ void add_device_reduce_instance_threadwise(
Rank, \
NumReduceDim)
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_THREADWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
#ifndef CHECK_ERR_HPP
#define CHECK_ERR_HPP
#pragma once
#include <algorithm>
#include <cmath>
......@@ -169,17 +168,34 @@ check_err(const std::vector<T>& out,
return false;
}
bool res{true};
int err_count = 0;
int64_t err = 0;
int64_t max_err = std::numeric_limits<int64_t>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
if(out[i] != ref[i])
int64_t o = out[i];
int64_t r = ref[i];
err = std::abs(o - r);
if(err > 0)
{
std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast<int>(out[i])
<< " != " << static_cast<int>(ref[i]) << std::endl
<< msg << std::endl;
return false;
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast<int>(out[i])
<< " != " << static_cast<int>(ref[i]) << std::endl
<< msg << std::endl;
}
res = false;
}
}
return true;
if(!res)
{
std::cout << "max err: " << max_err << std::endl;
}
return res;
}
} // namespace utils
......@@ -191,5 +207,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os;
}
#endif
......@@ -402,8 +402,8 @@ template <typename InDataType,
typename InElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename WeiElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename OutElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename InputInitFun = FillUniform<InDataType>,
typename WeightsInitFun = FillUniform<WeiDataType>>
typename InputInitFun = FillUniformDistribution<InDataType>,
typename WeightsInitFun = FillUniformDistribution<WeiDataType>>
class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>
{
using DeviceConvFwdOp = tensor_operation::device::
......@@ -422,8 +422,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
ConvFwdOpInstance(const ConvParams& params,
bool do_init = true,
const InputInitFun& input_init_f = InputInitFun{},
const WeightsInitFun& weights_init_f = WeightsInitFun{})
const InputInitFun& input_init_f = InputInitFun(),
const WeightsInitFun& weights_init_f = WeightsInitFun())
: BaseType(),
params_{params},
output_spatial_lengths_{params.GetOutputSpatialLengths()},
......@@ -560,8 +560,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
const ConvParams& params_;
const std::vector<ck::index_t> output_spatial_lengths_;
const bool do_init_;
const InputInitFun& input_init_f_;
const WeightsInitFun& weights_init_f_;
InputInitFun input_init_f_;
WeightsInitFun weights_init_f_;
};
} // namespace conv
......
#pragma once
#include <algorithm>
#include <cmath>
#include <random>
#include "data_type.hpp"
......@@ -8,43 +9,53 @@
namespace ck {
namespace utils {
// template <typename T, class Enable = void>
// struct FillUniform;
template <typename T>
struct FillUniformDistribution
{
float a_{-5.f};
float b_{5.f};
// TODO: what's wrong with this specialization???
// err: segmentation fault in mt19937 - infinite loop like.
// template <typename T>
// struct FillUniform<T, typename std::enable_if<std::is_integral<T>::value &&
// !std::is_same<T, bhalf_t>::value>::type>
// {
// int a_{0};
// int b_{5};
// // T a_ = T{0};
// // T b_ = T{5};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen(11939);
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
}
};
// template <typename ForwardIter>
// void operator()(ForwardIter first, ForwardIter last) const
// {
// std::mt19937 gen{11939};
// std::uniform_int_distribution<int> dis(a_, b_);
// std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
// }
// };
// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below.
// However this produces segfaults in std::mt19937 which look like inifite loop.
// template <typename T>
// struct FillUniformDistributionIntegerValue
// {
// int a_{-5};
// int b_{5};
//
// template <typename ForwardIter>
// void operator()(ForwardIter first, ForwardIter last) const
// {
// std::mt19937 gen(11939);
// std::uniform_int_distribution<int> dis(a_, b_);
// std::generate(
// first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
// }
// };
// struct FillUniform<T, typename std::enable_if<std::is_floating_point<T>::value ||
// std::is_same<T, bhalf_t>::value>::type>
// Workaround for uniform_int_distribution not working as expected. See note above.<
template <typename T>
struct FillUniform
struct FillUniformDistributionIntegerValue
{
float a_{0};
float b_{5};
float a_{-5.f};
float b_{5.f};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen{11939};
std::uniform_real_distribution<> dis(a_, b_);
std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
std::mt19937 gen(11939);
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(
first, last, [&dis, &gen]() { return ck::type_convert<T>(std::round(dis(gen))); });
}
};
......
#pragma once
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
......@@ -78,7 +79,8 @@ class OpInstanceRunEngine
template <typename ReferenceOp = std::function<void()>>
OpInstanceRunEngine(const OpInstanceT& op_instance,
const ReferenceOp& reference_op = ReferenceOp{})
const ReferenceOp& reference_op = ReferenceOp{},
bool do_verification = true)
: op_instance_{op_instance}
{
in_tensors_ = op_instance_.GetInputTensors();
......@@ -88,8 +90,11 @@ class OpInstanceRunEngine
const Tensor<InArgTypes>&...,
Tensor<OutDataType>&>)
{
ref_output_ = op_instance_.GetOutputTensor();
CallRefOpUnpackArgs(reference_op, std::make_index_sequence<kNInArgs_>{});
if(do_verification)
{
ref_output_ = op_instance_.GetOutputTensor();
CallRefOpUnpackArgs(reference_op, std::make_index_sequence<kNInArgs_>{});
}
}
AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
out_device_buffer_ =
......@@ -110,6 +115,7 @@ class OpInstanceRunEngine
op_ptr.get(), in_device_buffers_, out_device_buffer_);
if(op_ptr->IsSupportedArgument(argument.get()))
{
std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl;
invoker->Run(argument.get());
out_device_buffer_->FromDevice(out_tensor_->mData.data());
if(!ref_output_)
......@@ -119,9 +125,16 @@ class OpInstanceRunEngine
" You have to provide reference function.");
}
// TODO: enable flexible use of custom check_error functions
res = res && check_err(out_tensor_->mData, ref_output_->mData);
bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData);
std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl;
res = res && inst_res;
out_device_buffer_->SetZero();
}
else
{
std::cout << "Given conv problem is not supported by instance: \n\t>>>>"
<< op_ptr->GetTypeString() << std::endl;
}
}
return res;
}
......@@ -132,7 +145,6 @@ class OpInstanceRunEngine
bool do_verification = false,
bool do_log = false)
{
bool res{true};
ProfileBestConfig best_config;
for(auto& op_ptr : op_ptrs)
......@@ -153,7 +165,7 @@ class OpInstanceRunEngine
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops < best_config.best_tflops)
if(avg_time < best_config.best_avg_time)
{
best_config.best_op_name = op_name;
best_config.best_tflops = tflops;
......@@ -171,7 +183,7 @@ class OpInstanceRunEngine
" You have to provide reference function.");
}
// TODO: enable flexible use of custom check_error functions
res = res && CheckErr(out_tensor_->mData, ref_output_->mData);
CheckErr(out_tensor_->mData, ref_output_->mData);
if(do_log) {}
}
......@@ -223,7 +235,7 @@ class OpInstanceRunEngine
template <typename T>
bool CheckErr(const std::vector<T>& dev_out, const std::vector<T>& ref_out) const
{
return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", atol_, rtol_);
return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_);
}
};
......
......@@ -20,7 +20,7 @@ include_directories(BEFORE
function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}")
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endfunction(add_instance_library INSTANCE_NAME)
......@@ -30,6 +30,7 @@ add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd)
......@@ -43,13 +44,14 @@ add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_gemm)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_atomic_add_instance>
$<TARGET_OBJECTS:device_gemm_instance>
......@@ -62,19 +64,20 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_conv3d_fwd_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
device_conv2d.cpp
)
add_library(composablekernels::device_operations ALIAS device_operations)
set(DEV_OPS_INC_DIRS
set(DEV_OPS_INC_DIRS
${PROJECT_SOURCE_DIR}/include/ck/
${PROJECT_SOURCE_DIR}/library/include/ck/
${PROJECT_SOURCE_DIR}/external/include/
)
target_compile_features(device_operations PUBLIC)
set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(device_operations PUBLIC
target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/utility>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_description>
......@@ -96,9 +99,11 @@ target_include_directories(device_operations PUBLIC
#once new arches are enabled make this an option on the main cmake file
# and pass down here to be exported
target_compile_options(device_operations
PRIVATE --offload-arch=gfx908
target_compile_options(device_operations PRIVATE
--offload-arch=gfx908
--offload-arch=gfx90a
)
# install(TARGETS device_operations LIBRARY DESTINATION lib)
install(TARGETS device_operations
EXPORT device_operationsTargets
......@@ -108,8 +113,8 @@ install(TARGETS device_operations
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
install(EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
install(EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
......@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
......@@ -59,12 +59,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 =
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple<
// clang-format off
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if !CK_WORKAROUND_GITHUB_135
// FIXME: this instance causes numerical errors.
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
#endif
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
......
......@@ -6,7 +6,18 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
)
set(DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
)
add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
add_library(device_convnd_2d_fwd_instance OBJECT ${DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE})
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_convnd_2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_conv2d_fwd_instance)
clang_tidy_check(device_convnd_2d_fwd_instance)
#include <stdlib.h>
#include "config.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {
using BF16 = ck::bhalf_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances =
std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances =
std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances =
std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{});
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{});
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{});
}
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment