Unverified Commit 6f0564f0 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Rangify FillUniformDistributionIntegerValue<> (#443)

Allow passing forward range to its call operator
parent 70456328
......@@ -32,14 +32,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
a_m_k.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
b_k_n.end());
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......
......@@ -77,15 +77,12 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input.begin(),
conv_input.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-8, 7}(conv_weight.begin(),
conv_weight.end());
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-8, 7}(conv_weight);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-5, 5}(conv_input.begin(), conv_input.end());
ck::utils::FillUniformDistribution<BDataType>{-5, 5}(conv_weight.begin(),
conv_weight.end());
ck::utils::FillUniformDistribution<ADataType>{-5, 5}(conv_input);
ck::utils::FillUniformDistribution<BDataType>{-5, 5}(conv_weight);
}
DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize());
......
......@@ -160,14 +160,12 @@ bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M,
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
a_m_k.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
b_k_n.end());
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
}
......
......@@ -134,14 +134,12 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
a_m_k.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
b_k_n.end());
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
}
......@@ -339,14 +337,12 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M,
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
a_m_k.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
b_k_n.end());
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
}
......
......@@ -100,9 +100,9 @@ int main(int argc, char* argv[])
Tensor<GammaDataType> gamma({G, C});
Tensor<BetaDataType> beta({G, C});
ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x.begin(), x.end());
ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma.begin(), gamma.end());
ck::utils::FillUniformDistribution<BetaDataType>{0.f, 1.f}(beta.begin(), beta.end());
ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x);
ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma);
ck::utils::FillUniformDistribution<BetaDataType>{0.f, 1.f}(beta);
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
......
......@@ -30,9 +30,10 @@ struct FillUniformDistribution
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) -> std::void_t<decltype(
std::declval<FillUniformDistribution>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
......@@ -72,6 +73,16 @@ struct FillUniformDistributionIntegerValue
std::generate(
first, last, [&dis, &gen]() { return ck::type_convert<T>(std::round(dis(gen))); });
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const FillUniformDistributionIntegerValue&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
template <typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment