Unverified Commit 9c3ce77c authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Shrink init value range for FA examples (#23)

parent 95889861
...@@ -71,13 +71,13 @@ int main(int argc, char* argv[]) ...@@ -71,13 +71,13 @@ int main(int argc, char* argv[])
Tensor<ODataType> o_host_dev(o_lengths, o_strides); Tensor<ODataType> o_host_dev(o_lengths, o_strides);
#if 0 #if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host); ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host); ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host); ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
#else #else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host); ck::utils::FillUniformDistribution<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host); ck::utils::FillUniformDistribution<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host); ck::utils::FillUniformDistribution<VDataType>{-2.f, 2.f}(v_host);
#endif #endif
// reference // reference
......
...@@ -109,13 +109,13 @@ int main(int argc, char* argv[]) ...@@ -109,13 +109,13 @@ int main(int argc, char* argv[])
Tensor<ODataType> o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v)); Tensor<ODataType> o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));
#if 0 #if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host); ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host); ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host); ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
#else #else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host); ck::utils::FillUniformDistribution<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host); ck::utils::FillUniformDistribution<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host); ck::utils::FillUniformDistribution<VDataType>{-2.f, 2.f}(v_host);
#endif #endif
DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize()); DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
......
...@@ -69,13 +69,13 @@ int main(int argc, char* argv[]) ...@@ -69,13 +69,13 @@ int main(int argc, char* argv[])
Tensor<ODataType> o_host_dev(o_lengths, o_strides); Tensor<ODataType> o_host_dev(o_lengths, o_strides);
#if 0 #if 0
ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host); ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host); ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host); ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
#else #else
ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host); ck::utils::FillUniformDistribution<QDataType>{-2.f, 2.f}(q_host);
ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host); ck::utils::FillUniformDistribution<KDataType>{-2.f, 2.f}(k_host);
ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host); ck::utils::FillUniformDistribution<VDataType>{-2.f, 2.f}(v_host);
#endif #endif
// reference // reference
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment