Commit b097be17 authored by root's avatar root
Browse files

merge changes for upstream/latest update

parents 8a891bbd a49115b9
...@@ -18,12 +18,12 @@ struct GeneratorTensor_0 ...@@ -18,12 +18,12 @@ struct GeneratorTensor_0
template <typename T> template <typename T>
struct GeneratorTensor_1 struct GeneratorTensor_1
{ {
int value = 1; T value = 1;
template <typename... Is> template <typename... Is>
T operator()(Is...) T operator()(Is...)
{ {
return ck::type_convert<T>(value); return value;
} }
}; };
......
...@@ -106,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -106,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
} }
float v_in; arg.in_element_op_(v_acc, v_acc);
arg.in_element_op_(v_in, v_acc); arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
......
...@@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator ...@@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
arg.a_element_op_(a, arg.a_m_k_(m, k)); arg.a_element_op_(a, ck::type_convert<AccDataType>(arg.a_m_k_(m, k)));
arg.b_element_op_(b, arg.b_k_n_(k, n)); arg.b_element_op_(b, ck::type_convert<AccDataType>(arg.b_k_n_(k, n)));
acc += a * b; acc += a * b;
} }
......
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename InDataType, typename OutDataType, typename AccDataType>
struct ReferenceSoftmax : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
const index_t rank,
const std::vector<index_t> sm_reduce_dims)
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
{
// std::cout << "debug: scalar dims: ";
for(int i = 0; i < rank; i++)
{
if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
sm_reduce_dims.end())
{
sm_scalar_dims_.push_back(i);
// std::cout << i << ", ";
}
}
// std::cout << std::endl;
}
const Tensor<InDataType>& in_;
Tensor<OutDataType>& out_;
AccDataType alpha_;
AccDataType beta_;
index_t rank_;
std::vector<index_t> sm_reduce_dims_;
std::vector<index_t> sm_scalar_dims_; // dim after internal max/sum reduction
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
std::vector<size_t> scalar_lengths;
for(index_t dim : arg.sm_scalar_dims_)
{
scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
}
Tensor<AccDataType> reduce_max(scalar_lengths);
reduce_max.GenerateTensorValue(
GeneratorTensor_1<AccDataType>{std::numeric_limits<AccDataType>::lowest()});
Tensor<AccDataType> reduce_sum(scalar_lengths);
reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
auto to_sm_scalar_idx = [&](auto idx) {
std::vector<size_t> sm_scalar_idx;
for(index_t dim : arg.sm_scalar_dims_)
{
sm_scalar_idx.push_back(idx[dim]);
}
return sm_scalar_idx;
};
arg.in_.ForEach([&](auto& self, auto idx) {
reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)),
static_cast<AccDataType>(self(idx)));
});
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
// std::endl;
Tensor<AccDataType> in_stable(arg.in_.mDesc);
in_stable.ForEach([&](auto& self, auto idx) {
// numerator = exp(x - max(x))
self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
reduce_max(to_sm_scalar_idx(idx)));
});
// LogRangeAsType<float>(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl;
in_stable.ForEach([&](auto& self, auto idx) {
// denominator = sum(exp(x - max(x)))
reduce_sum(to_sm_scalar_idx(idx)) += self(idx);
});
// LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",") <<
// std::endl;
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
arg.beta_ * self(idx);
});
// LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
// reduction along reduce dims
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
// std::endl; LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",")
// << std::endl;
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
const index_t rank,
const std::vector<index_t> sm_reduce_dims)
{
return Argument{in, out, alpha, beta, rank, sm_reduce_dims};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceSoftmax"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP #pragma once
#define CK_DEVICE_OPERATION_INSTANCE_HPP
#include <stdlib.h> #include <vector>
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -23,4 +22,3 @@ void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op ...@@ -23,4 +22,3 @@ void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -61,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple< ...@@ -61,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
>; >;
#endif #endif
template <typename AccDataType, ReduceTensorOp ReduceOpId> template <ReduceTensorOp ReduceOpId>
using deviceReduceBlockWisePtrType = DeviceReducePtr< using deviceReduceBlockWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation, typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>; typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
...@@ -75,14 +75,13 @@ template <typename InDataType, ...@@ -75,14 +75,13 @@ template <typename InDataType,
bool PropagateNan, bool PropagateNan,
bool UseIndex> bool UseIndex>
void add_device_reduce_instance_blockwise( void add_device_reduce_instance_blockwise(
std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances) std::vector<deviceReduceBlockWisePtrType<ReduceOpId>>& device_op_instances)
{ {
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>:: typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
AccElementwiseOperation;
constexpr bool Indexable = constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
...@@ -137,7 +136,7 @@ void add_device_reduce_instance_blockwise( ...@@ -137,7 +136,7 @@ void add_device_reduce_instance_blockwise(
ReduceOpId, \ ReduceOpId, \
PropagateNan, \ PropagateNan, \
UseIndex>( \ UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances) std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \ #define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...@@ -150,21 +149,17 @@ void add_device_reduce_instance_blockwise( ...@@ -150,21 +149,17 @@ void add_device_reduce_instance_blockwise(
Rank, \ Rank, \
NumReduceDim) NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ #define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \ extern template void add_device_reduce_instance_blockwise<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
NumReduceDim, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
PropagateNan, \ PropagateNan, \
UseIndex>( \ UseIndex>( \
std::vector<DeviceReducePtr< \ std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \ #define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
...@@ -61,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< ...@@ -61,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
>; >;
#endif #endif
template <typename AccDataType, ReduceTensorOp ReduceOperation> template <ReduceTensorOp ReduceOperation>
using deviceReduceMultiBlockAtomicAddPtrType = using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr<
DeviceReducePtr<typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>:: typename reduce_unary_operator<ReduceOperation, true, true>::InElementwiseOperation,
InElementwiseOperation, typename reduce_unary_operator<ReduceOperation, true, true>::AccElementwiseOperation>;
typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
AccElementwiseOperation>;
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
...@@ -77,15 +75,13 @@ template <typename InDataType, ...@@ -77,15 +75,13 @@ template <typename InDataType,
bool PropagateNan, bool PropagateNan,
bool UseIndex> bool UseIndex>
void add_device_reduce_instance_multiblock_atomic_add( void add_device_reduce_instance_multiblock_atomic_add(
std::vector<deviceReduceMultiBlockAtomicAddPtrType<AccDataType, ReduceOpId>>& std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>>& device_op_instances)
device_op_instances)
{ {
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>:: typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
AccElementwiseOperation;
constexpr bool Indexable = constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
...@@ -158,8 +154,7 @@ void add_device_reduce_instance_multiblock_atomic_add( ...@@ -158,8 +154,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
ReduceOpId, \ ReduceOpId, \
PropagateNan, \ PropagateNan, \
UseIndex>( \ UseIndex>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \ std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \ #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...@@ -172,21 +167,17 @@ void add_device_reduce_instance_multiblock_atomic_add( ...@@ -172,21 +167,17 @@ void add_device_reduce_instance_multiblock_atomic_add(
Rank, \ Rank, \
NumReduceDim) NumReduceDim)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \ #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \ extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
NumReduceDim, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
PropagateNan, \ PropagateNan, \
UseIndex>( \ UseIndex>( \
std::vector<DeviceReducePtr< \ std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \ #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
...@@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple< ...@@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple<
>; >;
#endif #endif
template <typename AccDataType, ReduceTensorOp ReduceOpId> template <ReduceTensorOp ReduceOpId>
using deviceReduceThreadWisePtrType = DeviceReducePtr< using deviceReduceThreadWisePtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation, typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>; typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
...@@ -61,14 +61,13 @@ template <typename InDataType, ...@@ -61,14 +61,13 @@ template <typename InDataType,
bool PropagateNan, bool PropagateNan,
bool UseIndex> bool UseIndex>
void add_device_reduce_instance_threadwise( void add_device_reduce_instance_threadwise(
std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances) std::vector<deviceReduceThreadWisePtrType<ReduceOpId>>& device_op_instances)
{ {
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation; typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>:: typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
AccElementwiseOperation;
constexpr bool Indexable = constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
...@@ -114,7 +113,7 @@ void add_device_reduce_instance_threadwise( ...@@ -114,7 +113,7 @@ void add_device_reduce_instance_threadwise(
ReduceOpId, \ ReduceOpId, \
PropagateNan, \ PropagateNan, \
UseIndex>( \ UseIndex>( \
std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances) std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_THREADWISE_INST_BY_ID( \ #define ADD_THREADWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...@@ -127,21 +126,17 @@ void add_device_reduce_instance_threadwise( ...@@ -127,21 +126,17 @@ void add_device_reduce_instance_threadwise(
Rank, \ Rank, \
NumReduceDim) NumReduceDim)
#define ADD_THREADWISE_INST_REF_BY_TYPE( \ #define ADD_THREADWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_threadwise<inT, \ extern template void add_device_reduce_instance_threadwise<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
NumReduceDim, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
PropagateNan, \ PropagateNan, \
UseIndex>( \ UseIndex>( \
std::vector<DeviceReducePtr< \ std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_THREADWISE_INST_REF_BY_ID( \ #define ADD_THREADWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
#ifndef CHECK_ERR_HPP #pragma once
#define CHECK_ERR_HPP
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
...@@ -169,17 +168,34 @@ check_err(const std::vector<T>& out, ...@@ -169,17 +168,34 @@ check_err(const std::vector<T>& out,
return false; return false;
} }
bool res{true};
int err_count = 0;
int64_t err = 0;
int64_t max_err = std::numeric_limits<int64_t>::min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
if(out[i] != ref[i]) int64_t o = out[i];
int64_t r = ref[i];
err = std::abs(o - r);
if(err > 0)
{ {
std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast<int>(out[i]) max_err = err > max_err ? err : max_err;
<< " != " << static_cast<int>(ref[i]) << std::endl err_count++;
<< msg << std::endl; if(err_count < 5)
return false; {
std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast<int>(out[i])
<< " != " << static_cast<int>(ref[i]) << std::endl
<< msg << std::endl;
}
res = false;
} }
} }
return true; if(!res)
{
std::cout << "max err: " << max_err << std::endl;
}
return res;
} }
} // namespace utils } // namespace utils
...@@ -191,5 +207,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) ...@@ -191,5 +207,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " ")); std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
return os; return os;
} }
#endif
...@@ -402,8 +402,8 @@ template <typename InDataType, ...@@ -402,8 +402,8 @@ template <typename InDataType,
typename InElementwiseOp = ck::tensor_operation::element_wise::PassThrough, typename InElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename WeiElementwiseOp = ck::tensor_operation::element_wise::PassThrough, typename WeiElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename OutElementwiseOp = ck::tensor_operation::element_wise::PassThrough, typename OutElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename InputInitFun = FillUniform<InDataType>, typename InputInitFun = FillUniformDistribution<InDataType>,
typename WeightsInitFun = FillUniform<WeiDataType>> typename WeightsInitFun = FillUniformDistribution<WeiDataType>>
class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, WeiDataType> class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>
{ {
using DeviceConvFwdOp = tensor_operation::device:: using DeviceConvFwdOp = tensor_operation::device::
...@@ -422,8 +422,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, ...@@ -422,8 +422,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
ConvFwdOpInstance(const ConvParams& params, ConvFwdOpInstance(const ConvParams& params,
bool do_init = true, bool do_init = true,
const InputInitFun& input_init_f = InputInitFun{}, const InputInitFun& input_init_f = InputInitFun(),
const WeightsInitFun& weights_init_f = WeightsInitFun{}) const WeightsInitFun& weights_init_f = WeightsInitFun())
: BaseType(), : BaseType(),
params_{params}, params_{params},
output_spatial_lengths_{params.GetOutputSpatialLengths()}, output_spatial_lengths_{params.GetOutputSpatialLengths()},
...@@ -560,8 +560,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, ...@@ -560,8 +560,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
const ConvParams& params_; const ConvParams& params_;
const std::vector<ck::index_t> output_spatial_lengths_; const std::vector<ck::index_t> output_spatial_lengths_;
const bool do_init_; const bool do_init_;
const InputInitFun& input_init_f_; InputInitFun input_init_f_;
const WeightsInitFun& weights_init_f_; WeightsInitFun weights_init_f_;
}; };
} // namespace conv } // namespace conv
......
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <cmath>
#include <random> #include <random>
#include "data_type.hpp" #include "data_type.hpp"
...@@ -8,43 +9,53 @@ ...@@ -8,43 +9,53 @@
namespace ck { namespace ck {
namespace utils { namespace utils {
// template <typename T, class Enable = void> template <typename T>
// struct FillUniform; struct FillUniformDistribution
{
float a_{-5.f};
float b_{5.f};
// TODO: what's wrong with this specialization??? template <typename ForwardIter>
// err: segmentation fault in mt19937 - infinite loop like. void operator()(ForwardIter first, ForwardIter last) const
// template <typename T> {
// struct FillUniform<T, typename std::enable_if<std::is_integral<T>::value && std::mt19937 gen(11939);
// !std::is_same<T, bhalf_t>::value>::type> std::uniform_real_distribution<float> dis(a_, b_);
// { std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
// int a_{0}; }
// int b_{5}; };
// // T a_ = T{0};
// // T b_ = T{5};
// template <typename ForwardIter> // Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below.
// void operator()(ForwardIter first, ForwardIter last) const // However this produces segfaults in std::mt19937 which look like inifite loop.
// { // template <typename T>
// std::mt19937 gen{11939}; // struct FillUniformDistributionIntegerValue
// std::uniform_int_distribution<int> dis(a_, b_); // {
// std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); }); // int a_{-5};
// } // int b_{5};
// }; //
// template <typename ForwardIter>
// void operator()(ForwardIter first, ForwardIter last) const
// {
// std::mt19937 gen(11939);
// std::uniform_int_distribution<int> dis(a_, b_);
// std::generate(
// first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
// }
// };
// struct FillUniform<T, typename std::enable_if<std::is_floating_point<T>::value || // Workaround for uniform_int_distribution not working as expected. See note above.<
// std::is_same<T, bhalf_t>::value>::type>
template <typename T> template <typename T>
struct FillUniform struct FillUniformDistributionIntegerValue
{ {
float a_{0}; float a_{-5.f};
float b_{5}; float b_{5.f};
template <typename ForwardIter> template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const void operator()(ForwardIter first, ForwardIter last) const
{ {
std::mt19937 gen{11939}; std::mt19937 gen(11939);
std::uniform_real_distribution<> dis(a_, b_); std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); }); std::generate(
first, last, [&dis, &gen]() { return ck::type_convert<T>(std::round(dis(gen))); });
} }
}; };
......
#pragma once #pragma once
#include <cstdlib> #include <cstdlib>
#include <iostream>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <stdexcept> #include <stdexcept>
...@@ -78,7 +79,8 @@ class OpInstanceRunEngine ...@@ -78,7 +79,8 @@ class OpInstanceRunEngine
template <typename ReferenceOp = std::function<void()>> template <typename ReferenceOp = std::function<void()>>
OpInstanceRunEngine(const OpInstanceT& op_instance, OpInstanceRunEngine(const OpInstanceT& op_instance,
const ReferenceOp& reference_op = ReferenceOp{}) const ReferenceOp& reference_op = ReferenceOp{},
bool do_verification = true)
: op_instance_{op_instance} : op_instance_{op_instance}
{ {
in_tensors_ = op_instance_.GetInputTensors(); in_tensors_ = op_instance_.GetInputTensors();
...@@ -88,8 +90,11 @@ class OpInstanceRunEngine ...@@ -88,8 +90,11 @@ class OpInstanceRunEngine
const Tensor<InArgTypes>&..., const Tensor<InArgTypes>&...,
Tensor<OutDataType>&>) Tensor<OutDataType>&>)
{ {
ref_output_ = op_instance_.GetOutputTensor(); if(do_verification)
CallRefOpUnpackArgs(reference_op, std::make_index_sequence<kNInArgs_>{}); {
ref_output_ = op_instance_.GetOutputTensor();
CallRefOpUnpackArgs(reference_op, std::make_index_sequence<kNInArgs_>{});
}
} }
AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{}); AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
out_device_buffer_ = out_device_buffer_ =
...@@ -110,6 +115,7 @@ class OpInstanceRunEngine ...@@ -110,6 +115,7 @@ class OpInstanceRunEngine
op_ptr.get(), in_device_buffers_, out_device_buffer_); op_ptr.get(), in_device_buffers_, out_device_buffer_);
if(op_ptr->IsSupportedArgument(argument.get())) if(op_ptr->IsSupportedArgument(argument.get()))
{ {
std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl;
invoker->Run(argument.get()); invoker->Run(argument.get());
out_device_buffer_->FromDevice(out_tensor_->mData.data()); out_device_buffer_->FromDevice(out_tensor_->mData.data());
if(!ref_output_) if(!ref_output_)
...@@ -119,9 +125,16 @@ class OpInstanceRunEngine ...@@ -119,9 +125,16 @@ class OpInstanceRunEngine
" You have to provide reference function."); " You have to provide reference function.");
} }
// TODO: enable flexible use of custom check_error functions // TODO: enable flexible use of custom check_error functions
res = res && check_err(out_tensor_->mData, ref_output_->mData); bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData);
std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl;
res = res && inst_res;
out_device_buffer_->SetZero(); out_device_buffer_->SetZero();
} }
else
{
std::cout << "Given conv problem is not supported by instance: \n\t>>>>"
<< op_ptr->GetTypeString() << std::endl;
}
} }
return res; return res;
} }
...@@ -132,7 +145,6 @@ class OpInstanceRunEngine ...@@ -132,7 +145,6 @@ class OpInstanceRunEngine
bool do_verification = false, bool do_verification = false,
bool do_log = false) bool do_log = false)
{ {
bool res{true};
ProfileBestConfig best_config; ProfileBestConfig best_config;
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
...@@ -153,7 +165,7 @@ class OpInstanceRunEngine ...@@ -153,7 +165,7 @@ class OpInstanceRunEngine
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl; << " GB/s, " << op_name << std::endl;
if(tflops < best_config.best_tflops) if(avg_time < best_config.best_avg_time)
{ {
best_config.best_op_name = op_name; best_config.best_op_name = op_name;
best_config.best_tflops = tflops; best_config.best_tflops = tflops;
...@@ -171,7 +183,7 @@ class OpInstanceRunEngine ...@@ -171,7 +183,7 @@ class OpInstanceRunEngine
" You have to provide reference function."); " You have to provide reference function.");
} }
// TODO: enable flexible use of custom check_error functions // TODO: enable flexible use of custom check_error functions
res = res && CheckErr(out_tensor_->mData, ref_output_->mData); CheckErr(out_tensor_->mData, ref_output_->mData);
if(do_log) {} if(do_log) {}
} }
...@@ -223,7 +235,7 @@ class OpInstanceRunEngine ...@@ -223,7 +235,7 @@ class OpInstanceRunEngine
template <typename T> template <typename T>
bool CheckErr(const std::vector<T>& dev_out, const std::vector<T>& ref_out) const bool CheckErr(const std::vector<T>& dev_out, const std::vector<T>& ref_out) const
{ {
return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", atol_, rtol_); return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_);
} }
}; };
......
...@@ -20,7 +20,7 @@ include_directories(BEFORE ...@@ -20,7 +20,7 @@ include_directories(BEFORE
function(add_instance_library INSTANCE_NAME) function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}") message("adding instance ${INSTANCE_NAME}")
add_library(${INSTANCE_NAME} OBJECT ${ARGN}) add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC) target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endfunction(add_instance_library INSTANCE_NAME) endfunction(add_instance_library INSTANCE_NAME)
...@@ -30,6 +30,7 @@ add_subdirectory(gemm_bias2d) ...@@ -30,6 +30,7 @@ add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu) add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add) add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce) add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(batched_gemm) add_subdirectory(batched_gemm)
add_subdirectory(conv1d_fwd) add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd) add_subdirectory(conv2d_fwd)
...@@ -43,13 +44,14 @@ add_subdirectory(convnd_bwd_data) ...@@ -43,13 +44,14 @@ add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm)
add_subdirectory(conv2d_bwd_weight) add_subdirectory(conv2d_bwd_weight)
add_subdirectory(batched_gemm_reduce) add_subdirectory(batched_gemm_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_library(device_operations STATIC add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance> $<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance> $<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance> $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance> $<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance> $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance> $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_atomic_add_instance> $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_atomic_add_instance>
$<TARGET_OBJECTS:device_gemm_instance> $<TARGET_OBJECTS:device_gemm_instance>
...@@ -62,19 +64,20 @@ add_library(device_operations STATIC ...@@ -62,19 +64,20 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance> $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance> $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_conv3d_fwd_instance> $<TARGET_OBJECTS:device_conv3d_fwd_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
device_conv2d.cpp device_conv2d.cpp
) )
add_library(composablekernels::device_operations ALIAS device_operations) add_library(composablekernels::device_operations ALIAS device_operations)
set(DEV_OPS_INC_DIRS set(DEV_OPS_INC_DIRS
${PROJECT_SOURCE_DIR}/include/ck/ ${PROJECT_SOURCE_DIR}/include/ck/
${PROJECT_SOURCE_DIR}/library/include/ck/ ${PROJECT_SOURCE_DIR}/library/include/ck/
${PROJECT_SOURCE_DIR}/external/include/ ${PROJECT_SOURCE_DIR}/external/include/
) )
target_compile_features(device_operations PUBLIC) target_compile_features(device_operations PUBLIC)
set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(device_operations PUBLIC target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/utility> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/utility>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_description> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_description>
...@@ -96,9 +99,11 @@ target_include_directories(device_operations PUBLIC ...@@ -96,9 +99,11 @@ target_include_directories(device_operations PUBLIC
#once new arches are enabled make this an option on the main cmake file #once new arches are enabled make this an option on the main cmake file
# and pass down here to be exported # and pass down here to be exported
target_compile_options(device_operations target_compile_options(device_operations PRIVATE
PRIVATE --offload-arch=gfx908 --offload-arch=gfx908
--offload-arch=gfx90a
) )
# install(TARGETS device_operations LIBRARY DESTINATION lib) # install(TARGETS device_operations LIBRARY DESTINATION lib)
install(TARGETS device_operations install(TARGETS device_operations
EXPORT device_operationsTargets EXPORT device_operationsTargets
...@@ -108,8 +113,8 @@ install(TARGETS device_operations ...@@ -108,8 +113,8 @@ install(TARGETS device_operations
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
) )
install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
install(EXPORT device_operationsTargets install(EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel:: NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
) )
...@@ -21,11 +21,11 @@ template <ck::index_t... Is> ...@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>; using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>; using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>; using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>; using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>; using DOutElementOps = ck::Tuple<Identity, Identity>;
...@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in ...@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
>; >;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal, std::vector<
PassThrough, DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
PassThrough, instances)
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -21,11 +21,11 @@ template <ck::index_t... Is> ...@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>; using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>; using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>; using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>; using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>; using DOutElementOps = ck::Tuple<Identity, Identity>;
...@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in ...@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
>; >;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal, std::vector<
PassThrough, DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
PassThrough, instances)
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -21,11 +21,11 @@ template <ck::index_t... Is> ...@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>; using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>; using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>; using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>; using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>; using DOutElementOps = ck::Tuple<Identity, Identity>;
...@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in ...@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in
>; >;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal, std::vector<
PassThrough, DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
PassThrough, instances)
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -21,11 +21,11 @@ template <ck::index_t... Is> ...@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>; using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>; using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>; using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>; using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>; using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>; using DOutElementOps = ck::Tuple<Identity, Identity>;
...@@ -59,12 +59,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in ...@@ -59,12 +59,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in
>; >;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal, std::vector<
PassThrough, DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
PassThrough, instances)
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 = ...@@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 =
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple< using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple<
// clang-format off // clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if !CK_WORKAROUND_GITHUB_135
// FIXME: this instance causes numerical errors.
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
#endif
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
......
...@@ -6,7 +6,18 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE ...@@ -6,7 +6,18 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
) )
set(DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
)
add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
add_library(device_convnd_2d_fwd_instance OBJECT ${DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE})
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(device_convnd_2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_conv2d_fwd_instance) clang_tidy_check(device_conv2d_fwd_instance)
clang_tidy_check(device_convnd_2d_fwd_instance)
#include <stdlib.h>
#include "config.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {
using BF16 = ck::bhalf_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances =
std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances =
std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances =
std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{});
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{});
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{});
}
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment