Unverified Commit 9c3ce77c authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Shrink init value range for FA examples (#23)

parent 95889861
......@@ -71,13 +71,13 @@ int main(int argc, char* argv[])
Tensor<ODataType> o_host_dev(o_lengths, o_strides);
#if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
#else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
ck::utils::FillUniformDistribution<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-2.f, 2.f}(v_host);
#endif
// reference
......
......@@ -109,13 +109,13 @@ int main(int argc, char* argv[])
Tensor<ODataType> o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));
#if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
#else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
ck::utils::FillUniformDistribution<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-2.f, 2.f}(v_host);
#endif
DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
......
......@@ -69,13 +69,13 @@ int main(int argc, char* argv[])
Tensor<ODataType> o_host_dev(o_lengths, o_strides);
#if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
#else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
ck::utils::FillUniformDistribution<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-2.f, 2.f}(v_host);
#endif
// reference
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment