Commit 022ce136 authored by ltqin
Browse files

generator add bfloat16_t

parent 79ba5b99
......@@ -117,6 +117,53 @@ check_err(const Range& out,
return res;
}
/// Element-wise comparison of two bfloat16_t ranges against a mixed absolute/relative
/// tolerance.
///
/// Every element is widened through type_convert<float> and compared in double precision.
/// An element counts as a mismatch when |out - ref| > atol + rtol * |ref|, or when either
/// value is not finite.
///
/// @param out   Range holding the values under test.
/// @param ref   Range holding the reference values; must report the same size() as @p out.
/// @param msg   Prefix written to std::cerr in front of each diagnostic line.
/// @param rtol  Relative tolerance, scaled by |ref|.
/// @param atol  Absolute tolerance.
/// @return true when the sizes match and no element is a mismatch, false otherwise.
///
/// Only the first four mismatches are printed individually; when at least one mismatch
/// was found, a final "max err" line is printed as well.
template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_same_v<ranges::range_value_t<Range>, bfloat16_t>,
    bool>::type
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool res{true};
    int err_count = 0;
    // TODO: This seed is a hack kept from the original code; a proper limit for the 16-bit
    // type should be used instead. It is the smallest positive normal float, and a NaN error
    // never compares greater than it, so NaN mismatches leave max_err untouched.
    double max_err = std::numeric_limits<float>::min();

    // Walk both ranges in lock-step. Calling std::next(std::begin(x), i) for every element
    // costs O(i) per step on ranges without random access, i.e. quadratic time overall.
    auto out_it = std::begin(out);
    auto ref_it = std::begin(ref);
    for(std::size_t i = 0; i < ref.size(); ++i, ++out_it, ++ref_it)
    {
        const double o   = type_convert<float>(*out_it);
        const double r   = type_convert<float>(*ref_it);
        const double err = std::abs(o - r);
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            // Limit the per-element noise: only the first four mismatches are reported.
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }
    return res;
}
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
......@@ -166,7 +213,8 @@ check_err(const Range& out,
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>&&
!std::is_same_v<ranges::range_value_t<Range>, bfloat16_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
......
......@@ -43,6 +43,19 @@ struct GeneratorTensor_1<ck::bhalf_t>
}
};
/// GeneratorTensor_1 specialization for ck::bfloat16_t.
///
/// Produces the same constant for every call: `value` converted to ck::bfloat16_t through
/// ck::type_convert. All call arguments are accepted and ignored.
template <>
struct GeneratorTensor_1<ck::bfloat16_t>
{
    /// Constant (held as float) that each call converts and returns.
    float value{1.0f};

    template <typename... Indices>
    ck::bfloat16_t operator()(Indices...)
    {
        const auto converted = ck::type_convert<ck::bfloat16_t>(value);
        return converted;
    }
};
template <>
struct GeneratorTensor_1<int8_t>
{
......@@ -82,6 +95,20 @@ struct GeneratorTensor_2<ck::bhalf_t>
}
};
/// GeneratorTensor_2 specialization for ck::bfloat16_t.
///
/// Each call draws one value from std::rand(), reduces it modulo (max_value - min_value),
/// adds min_value, and converts the integer result to ck::bfloat16_t by way of float.
/// For max_value > min_value the generated integer lies in [min_value, max_value).
/// All call arguments are accepted and ignored.
template <>
struct GeneratorTensor_2<ck::bfloat16_t>
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Is>
    ck::bfloat16_t operator()(Is...)
    {
        const int span = max_value - min_value;
        // With min_value == max_value the remainder operation would divide by zero, which
        // is undefined behaviour; the degenerate range simply yields min_value.
        const int offset = (span != 0) ? (std::rand() % span) : 0;
        const float tmp  = static_cast<float>(offset + min_value);
        return ck::type_convert<ck::bfloat16_t>(tmp);
    }
};
template <>
struct GeneratorTensor_2<int8_t>
{
......@@ -127,6 +154,23 @@ struct GeneratorTensor_3<ck::bhalf_t>
}
};
/// GeneratorTensor_3 specialization for ck::bfloat16_t.
///
/// Each call takes one std::rand() draw, divides it by RAND_MAX to obtain a fraction,
/// stretches that fraction linearly between min_value and max_value in float arithmetic,
/// and converts the result to ck::bfloat16_t through ck::type_convert.
/// All call arguments are accepted and ignored.
template <>
struct GeneratorTensor_3<ck::bfloat16_t>
{
    /// Value produced when the random fraction is 0.
    float min_value{0.0f};
    /// Value produced when the random fraction is 1.
    float max_value{1.0f};

    template <typename... Indices>
    ck::bfloat16_t operator()(Indices...)
    {
        const float fraction = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
        const float scaled   = min_value + fraction * (max_value - min_value);
        return ck::type_convert<ck::bfloat16_t>(scaled);
    }
};
template <typename T>
struct GeneratorTensor_4
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment