Commit e4e99a49 authored by Po-Yen, Chen

Use new utilities to shorten code

parent 7acbf104
@@ -6,27 +6,32 @@
 #include <algorithm>
 #include <cmath>
 #include <cstdlib>
-#include <iostream>
 #include <iomanip>
+#include <iostream>
 #include <iterator>
 #include <limits>
 #include <type_traits>
 #include <vector>

 #include "ck/ck.hpp"
+#include "ck/host_utility/io.hpp"
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/span.hpp"
 #include "ck/utility/type.hpp"
-#include "ck/host_utility/io.hpp"
+
+#include "ck/library/utility/ranges.hpp"

 namespace ck {
 namespace utils {

-template <typename T>
-typename std::enable_if<std::is_floating_point<T>::value && !std::is_same<T, half_t>::value,
-                        bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_floating_point_v<ranges::range_value_t<Range>> &&
+        !std::is_same_v<ranges::range_value_t<Range>, half_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol = 1e-5,
           double atol = 3e-6)
@@ -44,15 +49,17 @@ check_err(const std::vector<T>& out,
     double max_err = std::numeric_limits<double>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        err = std::abs(out[i] - ref[i]);
-        if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i]))
+        const double o = *std::next(std::begin(out), i);
+        const double r = *std::next(std::begin(ref), i);
+
+        err = std::abs(o - r);
+        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
             if(err_count < 5)
             {
                 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
-                          << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;
         }
@@ -64,10 +71,13 @@ check_err(const std::vector<T>& out,
     return res;
 }

-template <typename T>
-typename std::enable_if<std::is_same<T, bhalf_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol = 1e-3,
           double atol = 1e-3)
@@ -86,9 +96,9 @@ check_err(const std::vector<T>& out,
     double max_err = std::numeric_limits<float>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        double o = type_convert<float>(out[i]);
-        double r = type_convert<float>(ref[i]);
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
         err = std::abs(o - r);
         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;
@@ -108,10 +118,13 @@ check_err(const std::vector<T>& out,
     return res;
 }

-template <typename T>
-typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
-check_err(span<const T> out,
-          span<const T> ref,
+template <typename Range, typename RefRange>
+typename std::enable_if<
+    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+        std::is_same_v<ranges::range_value_t<Range>, half_t>,
+    bool>::type
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol = 1e-3,
           double atol = 1e-3)
@@ -126,12 +139,12 @@ check_err(span<const T> out,
     bool res{true};
     int err_count = 0;
     double err    = 0;
-    double max_err = std::numeric_limits<T>::min();
+    double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        double o = type_convert<float>(out[i]);
-        double r = type_convert<float>(ref[i]);
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
         err = std::abs(o - r);
         if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
         {
             max_err = err > max_err ? err : max_err;
@@ -151,26 +164,17 @@ check_err(span<const T> out,
     return res;
 }

-template <typename T>
-typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
-          const std::string& msg = "Error: Incorrect results!",
-          double rtol = 1e-3,
-          double atol = 1e-3)
-{
-    return check_err(span<const T>{out}, span<const T>{ref}, msg, rtol, atol);
-}
-
-template <typename T>
-std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
+template <typename Range, typename RefRange>
+std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+                  std::is_integral_v<ranges::range_value_t<Range>> &&
+                  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-                     || std::is_same_v<T, int4_t>
+                     || std::is_same_v<ranges::range_value_t<Range>, int4_t>
 #endif
                      ,
                  bool>
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+check_err(const Range& out,
+          const RefRange& ref,
           const std::string& msg = "Error: Incorrect results!",
          double = 0,
          double atol = 0)
@@ -188,9 +192,9 @@ check_err(const std::vector<T>& out,
     int64_t max_err = std::numeric_limits<int64_t>::min();
     for(std::size_t i = 0; i < ref.size(); ++i)
     {
-        int64_t o = out[i];
-        int64_t r = ref[i];
+        const int64_t o = *std::next(std::begin(out), i);
+        const int64_t r = *std::next(std::begin(ref), i);
         err = std::abs(o - r);
         if(err > atol)
         {
......
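With this change every `check_err` overload is selected on `ck::ranges::range_value_t` instead of a concrete `std::vector<T>`/`span<const T>`, so any pair of sized ranges with matching element types can be compared directly; this is what lets the profiler call sites below pass `Tensor` objects instead of `.mData`. A minimal usage sketch, not part of the diff (values made up):

#include <vector>

#include "ck/library/utility/check_err.hpp"

bool check_err_sketch()
{
    std::vector<float> out = {1.0f, 2.0f, 3.0f};
    std::vector<float> ref = {1.0f, 2.0f, 3.0f};

    // Resolves to the floating-point overload because
    // range_value_t<std::vector<float>> is float (and not half_t).
    return ck::utils::check_err(out, ref);
}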
@@ -5,6 +5,8 @@
 #include <hip/hip_runtime.h>

+#include "ck/library/utility/buffer.hpp"
+
 template <typename T>
 __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
 {
@@ -18,7 +20,7 @@ struct DeviceMem
 {
     DeviceMem() = delete;
     DeviceMem(std::size_t mem_size);
-    void* GetDeviceBuffer() const;
+    ck::utils::mutable_buffer GetDeviceBuffer() const;
     std::size_t GetBufferSize() const;
     void ToDevice(const void* p) const;
     void FromDevice(void* p) const;
......
@@ -72,6 +72,16 @@ struct FillUniformDistributionIntegerValue
         std::generate(
             first, last, [&dis, &gen]() { return ck::type_convert<T>(std::round(dis(gen))); });
     }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range)
+        -> std::void_t<decltype(std::declval<FillUniformDistributionIntegerValue>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
 };

 template <typename T>
......
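The added range overload simply forwards to the existing iterator overload; the `std::void_t<decltype(...)>` return type is expression SFINAE, so the overload drops out of resolution when the iterator form would not compile for the given range. A usage sketch, assuming the struct is templated on the element type and default-constructible as the iterator overload above suggests:

// Fill a whole container in one call instead of passing begin()/end().
std::vector<int32_t> values(16);

FillUniformDistributionIntegerValue<int32_t>{}(values);                       // new range form
FillUniformDistributionIntegerValue<int32_t>{}(values.begin(), values.end()); // iterator form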
@@ -25,16 +25,15 @@ void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
     auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
         float v = 0;

-        for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
+        for(int c = 0; c < wei.GetLengths()[1]; ++c)
         {
-            for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
+            for(int y = 0; y < wei.GetLengths()[2]; ++y)
             {
                 int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
-                for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
+                for(int x = 0; x < wei.GetLengths()[3]; ++x)
                 {
                     int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
-                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
-                       wi < in.mDesc.GetLengths()[3])
+                    if(hi >= 0 && hi < in.GetLengths()[2] && wi >= 0 && wi < in.GetLengths()[3])
                     {
                         v += ck::type_convert<float>(in(n, c, hi, wi)) *
                              ck::type_convert<float>(wei(k, c, y, x));
@@ -45,11 +44,9 @@ void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
         out(n, k, ho, wo) = ck::type_convert<TOut>(v);
     };

-    make_ParallelTensorFunctor(f_nchw,
-                               out.mDesc.GetLengths()[0],
-                               out.mDesc.GetLengths()[1],
-                               out.mDesc.GetLengths()[2],
-                               out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(
+        f_nchw, out.GetLengths()[0], out.GetLengths()[1], out.GetLengths()[2], out.GetLengths()[3])(
+        std::thread::hardware_concurrency());
 }

 template <typename TIn,
@@ -72,13 +69,13 @@ void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor<TIn>& in,
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};

-    const auto Di = in.mDesc.GetLengths()[1];
-    const auto Hi = in.mDesc.GetLengths()[2];
-    const auto Wi = in.mDesc.GetLengths()[3];
-    const auto Z  = wei.mDesc.GetLengths()[1];
-    const auto Y  = wei.mDesc.GetLengths()[2];
-    const auto X  = wei.mDesc.GetLengths()[3];
-    const auto C  = wei.mDesc.GetLengths()[4];
+    const auto Di = in.GetLengths()[1];
+    const auto Hi = in.GetLengths()[2];
+    const auto Wi = in.GetLengths()[3];
+    const auto Z  = wei.GetLengths()[1];
+    const auto Y  = wei.GetLengths()[2];
+    const auto X  = wei.GetLengths()[3];
+    const auto C  = wei.GetLengths()[4];

     auto f_ndhwc = [&](auto n, auto do_tmp, auto ho_tmp, auto wo_tmp, auto k) {
         // do__ must be converted to signed integer, otherwise zmin might be wrong in cases
@@ -144,9 +141,9 @@ void host_conv3d_ndhwc_kzyxc_ndhwk(const Tensor<TIn>& in,
     };

     make_ParallelTensorFunctor(f_ndhwc,
-                               out.mDesc.GetLengths()[0],
-                               out.mDesc.GetLengths()[1],
-                               out.mDesc.GetLengths()[2],
-                               out.mDesc.GetLengths()[3],
-                               out.mDesc.GetLengths()[4])(std::thread::hardware_concurrency() - 4);
+                               out.GetLengths()[0],
+                               out.GetLengths()[1],
+                               out.GetLengths()[2],
+                               out.GetLengths()[3],
+                               out.GetLengths()[4])(std::thread::hardware_concurrency() - 4);
 }
@@ -19,7 +19,7 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                         const CElementwiseOperation& c_element_op)
 {
     auto f_mk_kn_mn = [&](auto m, auto n) {
-        const int K = a_m_k.mDesc.GetLengths()[1];
+        const int K = a_m_k.GetLengths()[1];

         float v_acc = 0;
@@ -41,7 +41,6 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
         c_m_n(m, n) = v_c;
     };

-    make_ParallelTensorFunctor(f_mk_kn_mn,
-                               c_m_n.mDesc.GetLengths()[0],
-                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(f_mk_kn_mn, c_m_n.GetLengths()[0], c_m_n.GetLengths()[1])(
+        std::thread::hardware_concurrency());
 }
@@ -108,13 +108,13 @@ struct ReductionHost
     std::vector<std::array<size_t, NumReduceDim>> reduce_dim_indexes;
     std::vector<std::array<size_t, NumInvariantDim>> invariant_dim_indexes;

-    ReductionHost(HostTensorDescriptor& inDesc,
-                  HostTensorDescriptor& outDesc,
+    ReductionHost(const HostTensorDescriptor& inDesc,
+                  const HostTensorDescriptor& outDesc,
                   const std::vector<int>& invariantDims_,
                   const std::vector<int>& reduceDims_)
     {
         // this->outLengths = to_int_vector(outDesc.GetLengths());
-        this->outStrides    = outDesc.GetStrides();
+        this->outStrides.assign(outDesc.GetStrides().begin(), outDesc.GetStrides().end());
         this->invariantDims = invariantDims_;
         this->reduceDims    = reduceDims_;
......
@@ -14,6 +14,10 @@
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/span.hpp"

+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/ranges.hpp"
+
 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
 {
@@ -84,10 +88,10 @@ struct HostTensorDescriptor
         this->CalculateStrides();
     }

-    template <typename Range,
-              typename = std::enable_if_t<
-                  std::is_convertible_v<decltype(*std::begin(std::declval<Range>())), std::size_t>>>
-    HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
+    template <typename Lengths,
+              typename = std::enable_if_t<
+                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t>>>
+    HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
     {
         this->CalculateStrides();
     }
@@ -102,13 +106,12 @@ struct HostTensorDescriptor
     {
     }

-    template <
-        typename Range1,
-        typename Range2,
-        typename = std::enable_if_t<
-            std::is_convertible_v<decltype(*std::begin(std::declval<Range1>())), std::size_t> &&
-            std::is_convertible_v<decltype(*std::begin(std::declval<Range2>())), std::size_t>>>
-    HostTensorDescriptor(const Range1& lens, const Range2& strides)
+    template <typename Lengths,
+              typename Strides,
+              typename = std::enable_if_t<
+                  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
+                  std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>>>
+    HostTensorDescriptor(const Lengths& lens, const Strides& strides)
         : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
     {
     }
@@ -117,8 +120,8 @@ struct HostTensorDescriptor
     std::size_t GetElementSize() const;
     std::size_t GetElementSpaceSize() const;

-    const std::vector<std::size_t>& GetLengths() const;
-    const std::vector<std::size_t>& GetStrides() const;
+    ck::span<const std::size_t> GetLengths() const;
+    ck::span<const std::size_t> GetStrides() const;

     template <typename... Is>
     std::size_t GetOffsetFromMultiIndex(Is... is) const
@@ -234,37 +237,41 @@ auto make_ParallelTensorFunctor(F f, Xs... xs)
 }

 template <typename T>
-struct Tensor
+struct Tensor : private HostTensorDescriptor, private std::vector<T>
 {
     using Descriptor = HostTensorDescriptor;
     using Data       = std::vector<T>;

     template <typename X>
-    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    Tensor(std::initializer_list<X> lens) : Descriptor(lens), Data(GetElementSpaceSize())
     {
     }

-    template <typename X>
-    Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    template <typename X, typename Y>
+    Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
+        : Descriptor(lens, strides), Data(GetElementSpaceSize())
     {
     }

-    template <typename X, typename Y>
-    Tensor(std::vector<X> lens, std::vector<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
+    template <typename Lengths>
+    Tensor(const Lengths& lens) : Descriptor(lens), Data(GetElementSpaceSize())
     {
     }

-    Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
+    template <typename Lengths, typename Strides>
+    Tensor(const Lengths& lens, const Strides& strides)
+        : Descriptor(lens, strides), Data(GetElementSpaceSize())
+    {
+    }
+
+    Tensor(const Descriptor& desc) : Descriptor(desc), Data(GetElementSpaceSize()) {}

     template <typename OutT>
     Tensor<OutT> CopyAsType() const
     {
-        Tensor<OutT> ret(mDesc);
-        for(size_t i = 0; i < mData.size(); i++)
-        {
-            ret.mData[i] = ck::type_convert<OutT>(mData[i]);
-        }
+        Tensor<OutT> ret(GetDesc());
+
+        ck::ranges::transform(
+            *this, ret.begin(), [](auto value) { return ck::type_convert<OutT>(value); });
+
         return ret;
     }
@@ -282,36 +289,28 @@ struct Tensor
     {
     }

-    decltype(auto) GetLengths() const { return mDesc.GetLengths(); }
-
-    decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
-
-    std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
-
-    std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
-
-    std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
+    const Descriptor& GetDesc() const noexcept { return *this; }

-    std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
+    using Descriptor::GetElementSize;
+    using Descriptor::GetElementSpaceSize;
+    using Descriptor::GetLengths;
+    using Descriptor::GetNumOfDimension;
+    using Descriptor::GetStrides;

-    void SetZero()
-    {
-        for(auto& v : mData)
-        {
-            v = T{0};
-        }
-    }
+    std::size_t GetMemorySize() const { return sizeof(T) * GetElementSpaceSize(); }
+
+    void SetZero() { ck::ranges::fill<T>(*this, 0); }

     template <typename F>
     void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
     {
-        if(rank == mDesc.GetNumOfDimension())
+        if(rank == GetNumOfDimension())
         {
             f(*this, idx);
             return;
         }
         // else
-        for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
+        for(size_t i = 0; i < GetLengths()[rank]; i++)
         {
             idx[rank] = i;
             ForEach_impl(std::forward<F>(f), idx, rank + 1);
@@ -321,20 +320,20 @@ struct Tensor
     template <typename F>
     void ForEach(F&& f)
     {
-        std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
+        std::vector<size_t> idx(GetNumOfDimension(), 0);
         ForEach_impl(std::forward<F>(f), idx, size_t(0));
     }

     template <typename F>
     void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
     {
-        if(rank == mDesc.GetNumOfDimension())
+        if(rank == GetNumOfDimension())
         {
             f(*this, idx);
             return;
         }
         // else
-        for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
+        for(size_t i = 0; i < GetLengths()[rank]; i++)
         {
             idx[rank] = i;
             ForEach_impl(std::forward<const F>(f), idx, rank + 1);
@@ -344,40 +343,37 @@ struct Tensor
     template <typename F>
     void ForEach(const F&& f) const
     {
-        std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
+        std::vector<size_t> idx(GetNumOfDimension(), 0);
         ForEach_impl(std::forward<const F>(f), idx, size_t(0));
     }

     template <typename G>
     void GenerateTensorValue(G g, std::size_t num_thread = 1)
     {
-        switch(mDesc.GetNumOfDimension())
+        switch(GetNumOfDimension())
         {
         case 1: {
             auto f = [&](auto i) { (*this)(i) = g(i); };
-            make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread);
+            make_ParallelTensorFunctor(f, GetLengths()[0])(num_thread);
             break;
         }
         case 2: {
             auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); };
-            make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread);
+            make_ParallelTensorFunctor(f, GetLengths()[0], GetLengths()[1])(num_thread);
             break;
         }
         case 3: {
             auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
-            make_ParallelTensorFunctor(
-                f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread);
+            make_ParallelTensorFunctor(f, GetLengths()[0], GetLengths()[1], GetLengths()[2])(
+                num_thread);
             break;
         }
         case 4: {
             auto f = [&](auto i0, auto i1, auto i2, auto i3) {
                 (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
             };
-            make_ParallelTensorFunctor(f,
-                                       mDesc.GetLengths()[0],
-                                       mDesc.GetLengths()[1],
-                                       mDesc.GetLengths()[2],
-                                       mDesc.GetLengths()[3])(num_thread);
+            make_ParallelTensorFunctor(
+                f, GetLengths()[0], GetLengths()[1], GetLengths()[2], GetLengths()[3])(num_thread);
             break;
         }
         case 5: {
@@ -385,11 +381,11 @@ struct Tensor
             (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
             };
             make_ParallelTensorFunctor(f,
-                                       mDesc.GetLengths()[0],
-                                       mDesc.GetLengths()[1],
-                                       mDesc.GetLengths()[2],
-                                       mDesc.GetLengths()[3],
-                                       mDesc.GetLengths()[4])(num_thread);
+                                       GetLengths()[0],
+                                       GetLengths()[1],
+                                       GetLengths()[2],
+                                       GetLengths()[3],
+                                       GetLengths()[4])(num_thread);
             break;
         }
         case 6: {
@@ -397,12 +393,12 @@ struct Tensor
             (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4, i5);
             };
             make_ParallelTensorFunctor(f,
-                                       mDesc.GetLengths()[0],
-                                       mDesc.GetLengths()[1],
-                                       mDesc.GetLengths()[2],
-                                       mDesc.GetLengths()[3],
-                                       mDesc.GetLengths()[4],
-                                       mDesc.GetLengths()[5])(num_thread);
+                                       GetLengths()[0],
+                                       GetLengths()[1],
+                                       GetLengths()[2],
+                                       GetLengths()[3],
+                                       GetLengths()[4],
+                                       GetLengths()[5])(num_thread);
             break;
         }
         default: throw std::runtime_error("unspported dimension");
@@ -412,59 +408,58 @@ struct Tensor
     template <typename... Is>
     T& operator()(Is... is)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        return (*this)[GetOffsetFromMultiIndex(is...)];
     }

     template <typename... Is>
     const T& operator()(Is... is) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        return (*this)[GetOffsetFromMultiIndex(is...)];
     }

-    T& operator()(std::vector<std::size_t> idx)
-    {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
-    }
+    T& operator()(std::vector<std::size_t> idx) { return (*this)[GetOffsetFromMultiIndex(idx)]; }

     const T& operator()(std::vector<std::size_t> idx) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        return (*this)[GetOffsetFromMultiIndex(idx)];
     }

-    typename Data::iterator begin() { return mData.begin(); }
-
-    typename Data::iterator end() { return mData.end(); }
-
-    typename Data::pointer data() { return mData.data(); }
-
-    typename Data::const_iterator begin() const { return mData.begin(); }
-
-    typename Data::const_iterator end() const { return mData.end(); }
-
-    typename Data::const_pointer data() const { return mData.data(); }
-
-    typename Data::size_type size() const { return mData.size(); }
+    using Data::begin;
+    using Data::data;
+    using Data::end;
+    using Data::size;

     template <typename U = T>
     auto AsSpan() const
     {
-        constexpr std::size_t FromSize = sizeof(T);
-        constexpr std::size_t ToSize   = sizeof(U);
+        using namespace ck::literals;
+
+        constexpr std::size_t ToSize = sizeof(U);

         using Element = std::add_const_t<std::remove_reference_t<U>>;
-        return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
+
+        if(GetMemorySize() % ToSize != 0)
+        {
+            return ck::span<Element>(nullptr, 0_uz);
+        }
+
+        return ck::span<Element>{reinterpret_cast<Element*>(data()), GetMemorySize() / ToSize};
     }

     template <typename U = T>
     auto AsSpan()
     {
-        constexpr std::size_t FromSize = sizeof(T);
-        constexpr std::size_t ToSize   = sizeof(U);
+        using namespace ck::literals;
+
+        constexpr std::size_t ToSize = sizeof(U);

         using Element = std::remove_reference_t<U>;
-        return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
-    }

-    Descriptor mDesc;
-    Data mData;
+        if(GetMemorySize() % ToSize != 0)
+        {
+            return ck::span<Element>(nullptr, 0_uz);
+        }
+
+        return ck::span<Element>{reinterpret_cast<Element*>(data()), GetMemorySize() / ToSize};
+    }
 };
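`Tensor<T>` now derives privately from `HostTensorDescriptor` and `std::vector<T>`, re-exporting only the intended surface through `using` declarations; callers lose direct access to `mDesc`/`mData`, and the tensor itself models a contiguous sized range, which is what the new `check_err`/`LogRangeAsType` call sites rely on. A sketch of the migrated call patterns, not part of the diff (the shape is illustrative):

Tensor<float> t({2, 3, 4});

auto lens  = t.GetLengths();    // was t.mDesc.GetLengths()
auto bytes = t.GetMemorySize(); // was sizeof(float) * t.mDesc.GetElementSpaceSize()

t.SetZero();                    // was a hand-written loop over t.mData
float* raw = t.data();          // was t.mData.data()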
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iterator>
#include <utility>
#include "ck/utility/type.hpp"
namespace ck {
template <typename T>
using iter_value_t = typename std::iterator_traits<remove_cvref_t<T>>::value_type;
template <typename T>
using iter_reference_t = decltype(*std::declval<T&>());
template <typename T>
using iter_difference_t = typename std::iterator_traits<remove_cvref_t<T>>::difference_type;
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iterator>
#include <numeric>
namespace ck {
template <typename ForwardIterator, typename Size, typename T, typename BinaryOperation>
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op)
-> decltype(std::accumulate(first, std::next(first, count), init, op))
{
return std::accumulate(first, std::next(first, count), init, op);
}
} // namespace ck
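`accumulate_n` is a small fold-over-prefix helper; the trailing-return `decltype` keeps it SFINAE-friendly. For example, computing an element count from a lengths array (a sketch with made-up values, not part of the diff):

#include <cstddef>
#include <functional>
#include <vector>

std::size_t element_count_sketch()
{
    std::vector<std::size_t> lens = {2, 3, 4};

    // Folds the first lens.size() elements: 1 * 2 * 3 * 4 = 24.
    return ck::accumulate_n(lens.begin(), lens.size(), std::size_t{1}, std::multiplies<>{});
}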
@@ -103,8 +103,7 @@ class OpInstanceRunEngine
             }
         }
         AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
-        out_device_buffer_ = std::make_unique<DeviceMem>(sizeof(OutDataType) *
-                                                         out_tensor_->mDesc.GetElementSpaceSize());
+        out_device_buffer_ = std::make_unique<DeviceMem>(out_tensor_->GetMemorySize());
         out_device_buffer_->SetZero();
     }
@@ -219,10 +218,7 @@ class OpInstanceRunEngine
     void AllocateDeviceInputTensorsImpl()
     {
         const auto& ts = std::get<Index>(in_tensors_);
-        in_device_buffers_
-            .emplace_back(
-                std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<Index, InArgsTypesTuple>) *
-                                            ts->mDesc.GetElementSpaceSize()))
+        in_device_buffers_.emplace_back(std::make_unique<DeviceMem>(ts->GetMemorySize()))
             ->ToDevice(ts->mData.data());
     }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iterator>
#include <type_traits>
#include <utility>
#include "ck/library/utility/iterator.hpp"
namespace ck {
namespace ranges {
template <typename R>
using iterator_t = decltype(std::begin(std::declval<R&>()));
template <typename R>
using sentinel_t = decltype(std::end(std::declval<R&>()));
template <typename R>
using range_size_t = decltype(std::size(std::declval<R&>()));
template <typename R>
using range_difference_t = ck::iter_difference_t<ranges::iterator_t<R>>;
template <typename R>
using range_value_t = iter_value_t<ranges::iterator_t<R>>;
template <typename R>
using range_reference_t = iter_reference_t<ranges::iterator_t<R>>;
template <typename T, typename = void>
struct is_range : std::false_type
{
};
template <typename T>
struct is_range<
T,
std::void_t<decltype(std::begin(std::declval<T&>())), decltype(std::end(std::declval<T&>()))>>
: std::true_type
{
};
template <typename T>
inline constexpr bool is_range_v = is_range<T>::value;
template <typename T, typename = void>
struct is_sized_range : std::false_type
{
};
template <typename T>
struct is_sized_range<T, std::void_t<decltype(std::size(std::declval<T&>()))>>
: std::bool_constant<is_range_v<T>>
{
};
template <typename T>
inline constexpr bool is_sized_range_v = is_sized_range<T>::value;
template <typename Cont, typename Range, typename... Args>
auto to(Range&& range, Args&&... args) -> decltype(Cont{std::begin(std::forward<Range>(range)),
std::end(std::forward<Range>(range)),
std::forward<Args>(args)...})
{
return Cont{std::begin(std::forward<Range>(range)),
std::end(std::forward<Range>(range)),
std::forward<Args>(args)...};
}
} // namespace ranges
} // namespace ck
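These aliases back-port the C++20 `std::ranges` type queries to C++17; the new `check_err` overloads key their SFINAE on `range_value_t`, and `ranges::to` provides a pre-C++23-style range-to-container conversion. A small sketch of the traits in action, not part of the diff:

#include <type_traits>
#include <vector>

static_assert(std::is_same_v<ck::ranges::range_value_t<std::vector<float>>, float>);
static_assert(ck::ranges::is_sized_range_v<std::vector<float>>);
static_assert(!ck::ranges::is_range_v<int>);

// Materialize any iterable into a concrete container:
inline std::vector<int> to_vector_sketch()
{
    const int source[] = {1, 2, 3};
    return ck::ranges::to<std::vector<int>>(source);
}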
@@ -10,7 +10,10 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
     hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
 }

-void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
+ck::utils::mutable_buffer DeviceMem::GetDeviceBuffer() const
+{
+    return ck::utils::mutable_buffer(mpDeviceBuf, mMemSize);
+}

 std::size_t DeviceMem::GetBufferSize() const { return mMemSize; }
......
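`GetDeviceBuffer()` now returns a `ck::utils::mutable_buffer` carrying the pointer together with its size, which is why the profiler call sites below drop their `static_cast<...DataType*>` wrappers. The new `ck/library/utility/buffer.hpp` itself is not shown in this diff; a purely hypothetical minimal shape consistent with those call sites would be:

#include <cstddef>

namespace ck {
namespace utils {

// Hypothetical sketch only (the real header is not part of this diff).
// The call sites require construction from (void*, size) and an implicit
// conversion to arbitrary pointer types in place of the old static_casts.
class mutable_buffer
{
    void* data_       = nullptr;
    std::size_t size_ = 0;

    public:
    mutable_buffer(void* data, std::size_t size) : data_(data), size_(size) {}

    template <typename T>
    operator T*() const { return static_cast<T*>(data_); }

    std::size_t size() const { return size_; }
};

} // namespace utils
} // namespace ck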
@@ -36,9 +36,9 @@ std::size_t HostTensorDescriptor::GetElementSpaceSize() const
     return space;
 }

-const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { return mLens; }
-const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const { return mStrides; }
+ck::span<const std::size_t> HostTensorDescriptor::GetLengths() const { return mLens; }
+ck::span<const std::size_t> HostTensorDescriptor::GetStrides() const { return mStrides; }

 std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
 {
......
@@ -8,13 +8,16 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/utility/literals.hpp"

 namespace ck {
 namespace profiler {
@@ -105,21 +108,21 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
     BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
     BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;

+    using namespace ck::literals;
+
     auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                        std::size_t row,
                                        std::size_t col,
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
-        if(std::is_same<decltype(layout), Row>::value)
+        if constexpr(std::is_same_v<decltype(layout), Row>)
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
         }
         else
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
         }
     };
@@ -144,12 +147,12 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
     Tensor<RefAcc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
     Tensor<RefAcc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));

-    std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
-    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
-    std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl;
-    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
-    std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl;
-    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
+    std::cout << "a0_g_m_k: " << a0_g_m_k.GetDesc() << std::endl;
+    std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
+    std::cout << "d0_g_m_n: " << d0_g_m_n.GetDesc() << std::endl;
+    std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
+    std::cout << "d1_g_m_o: " << d1_g_m_o.GetDesc() << std::endl;
+    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.GetDesc() << std::endl;

     switch(init_method)
     {
@@ -169,19 +172,18 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
         d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
     }

-    DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
-    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
-    DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize());
-    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
-    DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
-    DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
-                                  e1_g_m_o_device_result.mDesc.GetElementSize());
+    DeviceMem a0_g_m_k_device_buf(a0_g_m_k.GetMemorySize());
+    DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
+    DeviceMem d0_g_m_n_device_buf(d0_g_m_n.GetMemorySize());
+    DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
+    DeviceMem d1_g_m_o_device_buf(d1_g_m_o.GetMemorySize());
+    DeviceMem e1_g_m_o_device_buf(e1_g_m_o_device_result.GetMemorySize());

-    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
-    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
-    d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data());
-    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
-    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
+    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.data());
+    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
+    d0_g_m_n_device_buf.ToDevice(d0_g_m_n.data());
+    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
+    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.data());

     auto a0_element_op = A0ElementOp{};
     auto b0_element_op = B0ElementOp{};
@@ -263,38 +265,40 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
     float best_tflops     = 0;
     float best_gb_per_sec = 0;

+    using ck::utils::to_array;
+
     // profile device op instances
     for(auto& op_ptr : op_ptrs)
     {
-        auto argument_ptr = op_ptr->MakeArgumentPointer(
-            static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
-            static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
-            std::array<const void*, 1>{d0_g_m_n_device_buf.GetDeviceBuffer()},
-            static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
-            std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
-            static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
-            M,
-            N,
-            K,
-            O,
-            BatchCount,
-            StrideA0,
-            StrideB0,
-            std::array<ck::index_t, 1>{StrideD0},
-            StrideB1,
-            std::array<ck::index_t, 1>{StrideD1},
-            StrideE1,
-            BatchStrideA0,
-            BatchStrideB0,
-            std::array<ck::index_t, 1>{BatchStrideD0},
-            BatchStrideB1,
-            std::array<ck::index_t, 1>{BatchStrideD1},
-            BatchStrideE1,
-            a0_element_op,
-            b0_element_op,
-            cde0_element_op,
-            b1_element_op,
-            cde1_element_op);
+        auto argument_ptr =
+            op_ptr->MakeArgumentPointer(a0_g_m_k_device_buf.GetDeviceBuffer(),
+                                        b0_g_k_n_device_buf.GetDeviceBuffer(),
+                                        to_array({d0_g_m_n_device_buf.GetDeviceBuffer()}),
+                                        b1_g_n_o_device_buf.GetDeviceBuffer(),
+                                        to_array({d1_g_m_o_device_buf.GetDeviceBuffer()}),
+                                        e1_g_m_o_device_buf.GetDeviceBuffer(),
+                                        M,
+                                        N,
+                                        K,
+                                        O,
+                                        BatchCount,
+                                        StrideA0,
+                                        StrideB0,
+                                        to_array({StrideD0}),
+                                        StrideB1,
+                                        to_array({StrideD1}),
+                                        StrideE1,
+                                        BatchStrideA0,
+                                        BatchStrideB0,
+                                        to_array({BatchStrideD0}),
+                                        BatchStrideB1,
+                                        to_array({BatchStrideD1}),
+                                        BatchStrideE1,
+                                        a0_element_op,
+                                        b0_element_op,
+                                        cde0_element_op,
+                                        b1_element_op,
+                                        cde1_element_op);

         auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -328,18 +332,17 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
         if(do_verification)
         {
-            e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
+            e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.data());

-            pass = pass & ck::utils::check_err(e1_g_m_o_device_result.mData,
-                                               e1_g_m_o_host_result.mData);
+            pass = pass & ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);

             if(do_log)
             {
                 LogRangeAsType<float>(
-                    std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",")
+                    std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result, ",")
                     << std::endl;
                 LogRangeAsType<float>(
-                    std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",")
+                    std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result, ",")
                     << std::endl;
             }
         }
......
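Two more of the new utilities recur throughout these profiler hunks: `ck::utils::to_array` (from the added `array.hpp` include) deduces the `std::array` element type and bound from a braced list, and `ck::literals` supplies a `_uz` suffix for `std::size_t` so brace-initialized length/stride lists stay homogeneous. A sketch, not part of the diff, assuming `to_array` deduces from a built-in array argument as the call sites suggest:

#include "ck/library/utility/array.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/literals.hpp"

void helpers_sketch()
{
    using namespace ck::literals;
    using ck::utils::to_array;

    // 1_uz keeps the stride list homogeneously std::size_t.
    HostTensorDescriptor desc({2_uz, 4_uz, 8_uz}, {32_uz, 8_uz, 1_uz});

    // Replaces the spelled-out std::array<ck::index_t, 1>{stride_d0}.
    ck::index_t stride_d0 = 16;
    auto stride_ds        = to_array({stride_d0});

    (void)desc;
    (void)stride_ds;
}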
@@ -6,17 +6,18 @@
 #include <memory>

 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/utility/literals.hpp"

 namespace ck {
 namespace profiler {
@@ -99,21 +100,21 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
     BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
     BatchStrideC  = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;

+    using namespace ck::literals;
+
     auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                        std::size_t row,
                                        std::size_t col,
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
-        if(std::is_same<decltype(layout), Row>::value)
+        if constexpr(std::is_same_v<decltype(layout), Row>)
        {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
        }
        else
        {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
@@ -131,10 +132,10 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
     // Host verification: Output of Gemm0 is input A of Gemm1
     Tensor<ADataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));

-    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
-    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
-    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
-    std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
+    std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
+    std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
+    std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
+    std::cout << "c_g_m_o: " << c_g_m_o_host_result.GetDesc() << std::endl;

     switch(init_method)
     {
@@ -160,14 +161,14 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
         b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }

-    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize());
-    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
-    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
-    DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize());
+    DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
+    DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
+    DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
+    DeviceMem c_g_m_o_device_buf(c_g_m_o_device_result.GetMemorySize());

-    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
-    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
-    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
+    a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
+    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
+    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());

     auto a_element_op  = AElementOp{};
     auto b0_element_op = B0ElementOp{};
@@ -226,29 +227,28 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
     // profile device op instances
     for(auto& op_ptr : op_ptrs)
     {
-        auto argument_ptr = op_ptr->MakeArgumentPointer(
-            static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
-            static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
-            static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
-            static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
-            M,
-            N,
-            K,
-            O,
-            BatchCount,
-            StrideA,
-            StrideB0,
-            StrideB1,
-            StrideC,
-            BatchStrideA,
-            BatchStrideB0,
-            BatchStrideB1,
-            BatchStrideC,
-            a_element_op,
-            b0_element_op,
-            acc0_element_op,
-            b1_element_op,
-            c_element_op);
+        auto argument_ptr = op_ptr->MakeArgumentPointer(a_g_m_k_device_buf.GetDeviceBuffer(),
+                                                        b0_g_k_n_device_buf.GetDeviceBuffer(),
+                                                        b1_g_n_o_device_buf.GetDeviceBuffer(),
+                                                        c_g_m_o_device_buf.GetDeviceBuffer(),
+                                                        M,
+                                                        N,
+                                                        K,
+                                                        O,
+                                                        BatchCount,
+                                                        StrideA,
+                                                        StrideB0,
+                                                        StrideB1,
+                                                        StrideC,
+                                                        BatchStrideA,
+                                                        BatchStrideB0,
+                                                        BatchStrideB1,
+                                                        BatchStrideC,
+                                                        a_element_op,
+                                                        b0_element_op,
+                                                        acc0_element_op,
+                                                        b1_element_op,
+                                                        c_element_op);

         auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -281,24 +281,20 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
         if(do_verification)
         {
-            c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
+            c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.data());

-            pass = pass &
-                   ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData);
+            pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result);

             if(do_log)
             {
-                LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",")
-                    << std::endl;
-                LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",")
-                    << std::endl;
-                LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",")
-                    << std::endl;
+                LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k, ",") << std::endl;
+                LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n, ",") << std::endl;
+                LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o, ",") << std::endl;
                 LogRangeAsType<float>(
-                    std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result.mData, ",")
+                    std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result, ",")
                     << std::endl;
                 LogRangeAsType<float>(
-                    std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",")
+                    std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result, ",")
                     << std::endl;
             }
         }
......
@@ -6,17 +6,18 @@
 #include <memory>

 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/utility/literals.hpp"

 namespace ck {
 namespace profiler {
@@ -44,21 +45,21 @@ bool profile_batched_gemm_impl(int do_verification,
 {
     bool pass = true;

+    using namespace ck::literals;
+
     auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                        std::size_t row,
                                        std::size_t col,
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
-        if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
+        if constexpr(is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
         }
         else
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
         }
     };
@@ -71,9 +72,9 @@ bool profile_batched_gemm_impl(int do_verification,
     Tensor<CDataType> c_g_m_n_device_result(
         f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));

-    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
-    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
-    std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;
+    std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
+    std::cout << "b_g_k_n: " << b_g_k_n.GetDesc() << std::endl;
+    std::cout << "c_g_m_n: " << c_g_m_n_host_result.GetDesc() << std::endl;

     switch(init_method)
     {
@@ -115,13 +116,13 @@ bool profile_batched_gemm_impl(int do_verification,
         ref_invoker.Run(ref_argument);
     }

-    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
-    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_device_buf(a_g_m_k.GetMemorySize());
+    DeviceMem b_device_buf(b_g_k_n.GetMemorySize());
+    DeviceMem c_device_buf(c_g_m_n_device_result.GetMemorySize());

-    a_device_buf.ToDevice(a_g_m_k.mData.data());
-    b_device_buf.ToDevice(b_g_k_n.mData.data());
-    c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());
+    a_device_buf.ToDevice(a_g_m_k.data());
+    b_device_buf.ToDevice(b_g_k_n.data());
+    c_device_buf.ToDevice(c_g_m_n_device_result.data());

     using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
                                                                      BLayout,
@@ -148,9 +149,9 @@ bool profile_batched_gemm_impl(int do_verification,
     for(auto& op_ptr : op_ptrs)
     {
         auto argument_ptr =
-            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+            op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                        b_device_buf.GetDeviceBuffer(),
+                                        c_device_buf.GetDeviceBuffer(),
                                         M,
                                         N,
                                         K,
@@ -177,7 +178,7 @@ bool profile_batched_gemm_impl(int do_verification,
             float ave_time =
                 invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

-            std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
+            std::size_t flop = 2_uz * BatchCount * M * N * K;
             std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                                      sizeof(CDataType) * M * N) *
@@ -200,19 +201,17 @@ bool profile_batched_gemm_impl(int do_verification,
         if(do_verification)
         {
-            c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
+            c_device_buf.FromDevice(c_g_m_n_device_result.data());

-            pass = pass &
-                   ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);
+            pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);

             if(do_log)
             {
-                LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
-                LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
-                LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
-                    << std::endl;
-                LogRangeAsType<float>(
-                    std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
-                    << std::endl;
+                LogRangeAsType<float>(std::cout << "a : ", a_g_m_k, ",") << std::endl;
+                LogRangeAsType<float>(std::cout << "b: ", b_g_k_n, ",") << std::endl;
+                LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result, ",")
+                    << std::endl;
+                LogRangeAsType<float>(std::cout << "c_device: ", c_g_m_n_device_result, ",")
+                    << std::endl;
             }
         }
......
@@ -6,18 +6,19 @@
 #include <memory>

 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 #include "ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/utility/literals.hpp"

 namespace ck {
 namespace profiler {
...@@ -106,21 +107,21 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -106,21 +107,21 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
const int BatchCount = G0 * G1; const int BatchCount = G0 * G1;
using namespace ck::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count, auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row, std::size_t row,
std::size_t col, std::size_t col,
std::size_t stride, std::size_t stride,
std::size_t batch_stride, std::size_t batch_stride,
auto layout) { auto layout) {
if(std::is_same<decltype(layout), Row>::value) if constexpr(std::is_same_v<decltype(layout), Row>)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
std::vector<std::size_t>({batch_stride, stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
std::vector<std::size_t>({batch_stride, 1, stride}));
} }
}; };
...@@ -131,22 +132,17 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -131,22 +132,17 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o( Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_gs_ms_os_host_result( Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()), Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
Tensor<CDataType> c_gs_ms_os_device_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
// Host verification: Output of Gemm0 is input A of Gemm1 // Host verification: Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O}, Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}, {M * O, O, 1});
std::vector<int>{M * O, O, 1});
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.GetDesc() << std::endl;
std::srand(1); // work around test flakiness std::srand(1); // work around test flakiness
switch(init_method) switch(init_method)
...@@ -180,15 +176,14 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -180,15 +176,14 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize()); DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) * DeviceMem c_gs_ms_os_device_buf(c_gs_ms_os_device_result.GetMemorySize());
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
...@@ -264,29 +259,28 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -264,29 +259,28 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
// profile device op instances // profile device op instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = op_ptr->MakeArgumentPointer( auto argument_ptr = op_ptr->MakeArgumentPointer(a_g_m_k_device_buf.GetDeviceBuffer(),
static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()), b0_g_k_n_device_buf.GetDeviceBuffer(),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()), b1_g_n_o_device_buf.GetDeviceBuffer(),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()), c_gs_ms_os_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()), M,
M, N,
N, K,
K, O,
O, BatchCount,
BatchCount, c_gs_ms_os_lengths,
c_gs_ms_os_lengths, c_gs_ms_os_strides,
c_gs_ms_os_strides, StrideA,
StrideA, StrideB0,
StrideB0, StrideB1,
StrideB1, BatchStrideA,
BatchStrideA, BatchStrideB0,
BatchStrideB0, BatchStrideB1,
BatchStrideB1, a_element_op,
a_element_op, b0_element_op,
b0_element_op, acc0_element_op,
acc0_element_op, b1_element_op,
b1_element_op, c_element_op);
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
...@@ -319,25 +313,21 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -319,25 +313,21 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
if(do_verification) if(do_verification)
{ {
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.data());
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData, pass =
c_gs_ms_os_host_result.mData); pass & ck::utils::check_err(c_gs_ms_os_device_result, c_gs_ms_os_host_result);
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",") LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k, ",") << std::endl;
<< std::endl; LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",") LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o, ",") << std::endl;
<< std::endl;
LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",")
<< std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",") std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "c_gs_ms_os_device_result : ", LogRangeAsType<float>(
c_gs_ms_os_device_result.mData, std::cout << "c_gs_ms_os_device_result : ", c_gs_ms_os_device_result, ",")
",")
<< std::endl; << std::endl;
} }
} }
......
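The repeated `sizeof(T) * mDesc.GetElementSpaceSize()` buffer sizes above collapse into `GetMemorySize()`, and `.mData.data()` into `.data()`. A minimal sketch of the accessor shape this implies, assuming densely packed data; the real Tensor and HostTensorDescriptor are declared in ck/library/utility/host_tensor.hpp and differ in detail, so the stub names below are hypothetical.

    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    struct HostTensorDescriptorStub // hypothetical stand-in
    {
        std::vector<std::size_t> lengths;

        // The real descriptor derives the element space from lengths and
        // strides; this stub assumes densely packed data.
        std::size_t GetElementSpaceSize() const
        {
            return std::accumulate(
                lengths.begin(), lengths.end(), std::size_t{1}, std::multiplies<>{});
        }
    };

    template <typename T>
    struct TensorStub // hypothetical stand-in
    {
        HostTensorDescriptorStub mDesc;
        std::vector<T> mData;

        // One call replaces the sizeof(T) * mDesc.GetElementSpaceSize() idiom.
        std::size_t GetMemorySize() const { return sizeof(T) * mDesc.GetElementSpaceSize(); }
        const HostTensorDescriptorStub& GetDesc() const { return mDesc; }
        T* data() { return mData.data(); }
        const T* data() const { return mData.data(); }
    };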
...@@ -4,17 +4,18 @@ ...@@ -4,17 +4,18 @@
#pragma once #pragma once
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/utility/literals.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -73,20 +74,20 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -73,20 +74,20 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
{ {
bool pass = true; bool pass = true;
using namespace ck::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count, auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row, std::size_t row,
std::size_t col, std::size_t col,
std::size_t stride, std::size_t stride,
auto layout) { auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz});
std::vector<std::size_t>({row * stride, stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride});
std::vector<std::size_t>({col * stride, 1, stride}));
} }
}; };
...@@ -95,23 +96,19 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -95,23 +96,19 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
Tensor<CDataType> c_g_m_n_host_result( Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>( Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor({BatchCount, M}));
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)}))); Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor({BatchCount, M}));
Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<CDataType> c_g_m_n_device_result( Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>( Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor({BatchCount, M}));
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)}))); Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor({BatchCount, M}));
Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.GetDesc() << std::endl;
std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; std::cout << "c_g_m_n: " << c_g_m_n_host_result.GetDesc() << std::endl;
std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; std::cout << "d0_g_m: " << d0_g_m_host_result.GetDesc() << std::endl;
std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; std::cout << "d1_g_m: " << d1_g_m_host_result.GetDesc() << std::endl;
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
switch(init_method) switch(init_method)
...@@ -194,19 +191,17 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -194,19 +191,17 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
} }
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b_g_k_n.GetMemorySize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(c_g_m_n_device_result.GetMemorySize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * DeviceMem reduce0_device_buf(d0_g_m_device_result.GetMemorySize());
d0_g_m_device_result.mDesc.GetElementSpaceSize()); DeviceMem reduce1_device_buf(d1_g_m_device_result.GetMemorySize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpaceSize());
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(), std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()}; reduce1_device_buf.GetDeviceBuffer()};
a_device_buf.ToDevice(a_g_m_k.mData.data()); a_device_buf.ToDevice(a_g_m_k.data());
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs; std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
...@@ -293,7 +288,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -293,7 +288,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
std::size_t flop = std::size_t(2) * BatchCount * M * N * K; std::size_t flop = 2_uz * BatchCount * M * N * K;
std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K +
sizeof(BDataType) * BatchCount * K * N + sizeof(BDataType) * BatchCount * K * N +
sizeof(CDataType) * BatchCount * M * N; sizeof(CDataType) * BatchCount * M * N;
...@@ -315,16 +310,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -315,16 +310,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_g_m_n_device_result.data());
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); reduce0_device_buf.FromDevice(d0_g_m_device_result.data());
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); reduce1_device_buf.FromDevice(d1_g_m_device_result.data());
bool c_error = bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData); bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result);
bool d0_error = bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result);
ck::utils::check_err(d0_g_m_device_result.mData, d0_g_m_host_result.mData);
bool d1_error =
ck::utils::check_err(d1_g_m_device_result.mData, d1_g_m_host_result.mData);
pass = pass && (c_error == true); pass = pass && (c_error == true);
pass = pass && (d0_error == true); pass = pass && (d0_error == true);
...@@ -332,22 +324,19 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -332,22 +324,19 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "a : ", a_g_m_k, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "b: ", b_g_k_n, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(std::cout << "c_device: ", c_g_m_n_device_result, ",")
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "d0_host: ", d0_g_m_host_result.mData, ",") LogRangeAsType<float>(std::cout << "d0_host: ", d0_g_m_host_result, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(std::cout << "d0_device: ", d0_g_m_device_result, ",")
std::cout << "d0_device: ", d0_g_m_device_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "d1_host: ", d1_g_m_host_result.mData, ",") LogRangeAsType<float>(std::cout << "d1_host: ", d1_g_m_host_result, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(std::cout << "d1_device: ", d1_g_m_device_result, ",")
std::cout << "d1_device: ", d1_g_m_device_result.mData, ",")
<< std::endl; << std::endl;
} }
} }
......
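The descriptor construction above also drops the explicit std::vector<std::size_t> wrappers in favor of braced lists, e.g. HostTensorDescriptor({BatchCount, M}). Below is a sketch of a constructor shape that would accept those call sites, templated per list so that int and size_t lists both deduce; this is an assumption about the real overload set, not its verbatim signature.

    #include <cstddef>
    #include <initializer_list>
    #include <vector>

    struct HostTensorDescriptorSketch // hypothetical stand-in, illustration only
    {
        std::vector<std::size_t> lengths_;
        std::vector<std::size_t> strides_;

        // Lengths-only form, as in HostTensorDescriptor({BatchCount, M});
        // the real class presumably derives packed strides in this case.
        template <typename X>
        explicit HostTensorDescriptorSketch(std::initializer_list<X> lengths)
            : lengths_(lengths.begin(), lengths.end())
        {
        }

        // Lengths-and-strides form, as in ({batch_count, row, col},
        // {batch_stride, stride, 1_uz}). Iterator-pair construction converts
        // each element, so mixed int/size_t call sites compile without
        // narrowing errors.
        template <typename X, typename Y>
        HostTensorDescriptorSketch(std::initializer_list<X> lengths,
                                   std::initializer_list<Y> strides)
            : lengths_(lengths.begin(), lengths.end()),
              strides_(strides.begin(), strides.end())
        {
        }
    };

    // HostTensorDescriptorSketch desc({3, 4, 5}, {20, 5, 1});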
...@@ -6,18 +6,19 @@ ...@@ -6,18 +6,19 @@
#include <memory> #include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
namespace ck { namespace ck {
namespace profiler { namespace profiler {
...@@ -104,21 +105,21 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -104,21 +105,21 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
using namespace ck::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count, auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row, std::size_t row,
std::size_t col, std::size_t col,
std::size_t stride, std::size_t stride,
std::size_t batch_stride, std::size_t batch_stride,
auto layout) { auto layout) {
if(std::is_same<decltype(layout), Row>::value) if constexpr(std::is_same_v<decltype(layout), Row>)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
std::vector<std::size_t>({batch_stride, stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
std::vector<std::size_t>({batch_stride, 1, stride}));
} }
}; };
...@@ -137,10 +138,10 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -137,10 +138,10 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; std::cout << "c_g_m_o: " << c_g_m_o_host_result.GetDesc() << std::endl;
std::srand(1); // work around test flakiness std::srand(1); // work around test flakiness
switch(init_method) switch(init_method)
...@@ -174,14 +175,14 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -174,14 +175,14 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize()); DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize()); DeviceMem c_g_m_o_device_buf(c_g_m_o_device_result.GetMemorySize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
...@@ -240,29 +241,28 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -240,29 +241,28 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
// profile device op instances // profile device op instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = op_ptr->MakeArgumentPointer( auto argument_ptr = op_ptr->MakeArgumentPointer(a_g_m_k_device_buf.GetDeviceBuffer(),
static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()), b0_g_k_n_device_buf.GetDeviceBuffer(),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()), b1_g_n_o_device_buf.GetDeviceBuffer(),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()), c_g_m_o_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()), M,
M, N,
N, K,
K, O,
O, BatchCount,
BatchCount, StrideA,
StrideA, StrideB0,
StrideB0, StrideB1,
StrideB1, StrideC,
StrideC, BatchStrideA,
BatchStrideA, BatchStrideB0,
BatchStrideB0, BatchStrideB1,
BatchStrideB1, BatchStrideC,
BatchStrideC, a_element_op,
a_element_op, b0_element_op,
b0_element_op, acc0_element_op,
acc0_element_op, b1_element_op,
b1_element_op, c_element_op);
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
...@@ -295,24 +295,20 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -295,24 +295,20 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
if(do_verification) if(do_verification)
{ {
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.data());
pass = pass & pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result);
ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData);
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",") LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k, ",") << std::endl;
<< std::endl; LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",") LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o, ",") << std::endl;
<< std::endl;
LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",")
<< std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result.mData, ",") std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",") std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result, ",")
<< std::endl; << std::endl;
} }
} }
......
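Each f_host_tensor_descriptor lambda further switches from a runtime std::is_same<...>::value branch to if constexpr with std::is_same_v. A minimal self-contained sketch of the dispatch pattern, with hypothetical Row/Col stand-ins for the tags in tensor_layout.hpp:

    #include <cstddef>
    #include <type_traits>
    #include <vector>

    struct Row {}; // stand-in for the row-major layout tag
    struct Col {}; // stand-in for the column-major layout tag

    template <typename Layout>
    std::vector<std::size_t> strides_for(std::size_t stride, std::size_t batch_stride)
    {
        if constexpr(std::is_same_v<Layout, Row>)
            return {batch_stride, stride, std::size_t{1}}; // unit stride along columns
        else
            return {batch_stride, std::size_t{1}, stride}; // unit stride along rows
    }

Only the taken branch is instantiated, so each branch can use layout-specific expressions without having to compile for the other tag.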
...@@ -4,19 +4,19 @@ ...@@ -4,19 +4,19 @@
#pragma once #pragma once
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" #include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp" #include "ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
namespace ck { namespace ck {
namespace profiler { namespace profiler {
...@@ -25,16 +25,16 @@ template <typename DataType> ...@@ -25,16 +25,16 @@ template <typename DataType>
void show_data_nhwc_layout(Tensor<DataType>& nhwc) void show_data_nhwc_layout(Tensor<DataType>& nhwc)
{ {
std::cout << "["; std::cout << "[";
for(int n = 0; n < ck::type_convert<int>(nhwc.mDesc.GetLengths()[0]); n++) for(int n = 0; n < ck::type_convert<int>(nhwc.GetLengths()[0]); n++)
{ {
std::cout << "["; std::cout << "[";
for(int hi = 0; hi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[2]); hi++) for(int hi = 0; hi < ck::type_convert<int>(nhwc.GetLengths()[2]); hi++)
{ {
std::cout << "["; std::cout << "[";
for(int wi = 0; wi < ck::type_convert<int>(nhwc.mDesc.GetLengths()[3]); wi++) for(int wi = 0; wi < ck::type_convert<int>(nhwc.GetLengths()[3]); wi++)
{ {
std::cout << "["; std::cout << "[";
for(int c = 0; c < ck::type_convert<int>(nhwc.mDesc.GetLengths()[1]); c++) for(int c = 0; c < ck::type_convert<int>(nhwc.GetLengths()[1]); c++)
{ {
std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " "; std::cout << static_cast<float>(nhwc(n, c, hi, wi)) << " ";
} }
...@@ -82,9 +82,9 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -82,9 +82,9 @@ bool profile_conv_bwd_data_impl(int do_verification,
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc); Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
Tensor<OutDataType> output(out_g_n_k_wos_desc); Tensor<OutDataType> output(out_g_n_k_wos_desc);
std::cout << "input: " << input_host_result.mDesc << std::endl; std::cout << "input: " << input_host_result.GetDesc() << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl; std::cout << "weight: " << weight.GetDesc() << std::endl;
std::cout << "output: " << output.mDesc << std::endl; std::cout << "output: " << output.GetDesc() << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -98,12 +98,12 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -98,12 +98,12 @@ bool profile_conv_bwd_data_impl(int do_verification,
weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpaceSize()); DeviceMem in_device_buf(input_device_result.GetMemorySize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); DeviceMem wei_device_buf(weight.GetMemorySize());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(output.GetMemorySize());
out_device_buf.ToDevice(output.mData.data()); out_device_buf.ToDevice(output.data());
wei_device_buf.ToDevice(weight.mData.data()); wei_device_buf.ToDevice(weight.data());
if(do_verification) if(do_verification)
{ {
...@@ -157,23 +157,22 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -157,23 +157,22 @@ bool profile_conv_bwd_data_impl(int do_verification,
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), wei_device_buf.GetDeviceBuffer(),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), out_device_buf.GetDeviceBuffer(),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), conv_param.N_,
conv_param.N_, conv_param.K_,
conv_param.K_, conv_param.C_,
conv_param.C_, conv_param.input_spatial_lengths_,
conv_param.input_spatial_lengths_, conv_param.filter_spatial_lengths_,
conv_param.filter_spatial_lengths_, conv_param.output_spatial_lengths_,
conv_param.output_spatial_lengths_, conv_param.conv_filter_strides_,
conv_param.conv_filter_strides_, conv_param.conv_filter_dilations_,
conv_param.conv_filter_dilations_, conv_param.input_left_pads_,
conv_param.input_left_pads_, conv_param.input_right_pads_,
conv_param.input_right_pads_, in_element_op,
in_element_op, wei_element_op,
wei_element_op, out_element_op);
out_element_op);
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -207,10 +206,9 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -207,10 +206,9 @@ bool profile_conv_bwd_data_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
in_device_buf.FromDevice(input_device_result.mData.data()); in_device_buf.FromDevice(input_device_result.data());
pass = pass = pass & ck::utils::check_err(input_device_result, input_host_result);
pass & ck::utils::check_err(input_device_result.mData, input_host_result.mData);
if(do_log) if(do_log)
{ {
......
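Finally, the LogRangeAsType call sites now pass the Tensor itself instead of .mData, which works once the logger iterates any range. A simplified sketch of such a helper follows (a hypothetical log_range_as, not the version in ck/host_utility/io.hpp):

    #include <iostream>

    // Print each element of a range, converted to T, separated by delim.
    template <typename T, typename Range>
    std::ostream& log_range_as(std::ostream& os, const Range& range, const char* delim)
    {
        const char* sep = "";
        for(const auto& v : range)
        {
            os << sep << static_cast<T>(v);
            sep = delim;
        }
        return os;
    }

    // Usage in the style of the call sites above:
    // log_range_as<float>(std::cout << "c_device: ", c_g_m_o_device_result, ",")
    //     << std::endl;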