Commit 1475cb44 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add tensor generators

parent 35a32da2
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t> ...@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
}; };
#endif #endif
template <>
struct GeneratorTensor_1<ck::f4_t>
{
float value = 1.0;
template <typename... Is>
ck::bhalf_t operator()(Is...)
{
return ck::type_convert<ck::f4_t>(value);
}
};
template <> template <>
struct GeneratorTensor_1<int8_t> struct GeneratorTensor_1<int8_t>
{ {
...@@ -153,6 +165,20 @@ struct GeneratorTensor_2<ck::bf8_t> ...@@ -153,6 +165,20 @@ struct GeneratorTensor_2<ck::bf8_t>
}; };
#endif #endif
template <>
struct GeneratorTensor_2<ck::f4_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f4_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::type_convert<ck::f4_t>(tmp);
}
};
template <typename T> template <typename T>
struct GeneratorTensor_3 struct GeneratorTensor_3
{ {
...@@ -223,6 +249,22 @@ struct GeneratorTensor_3<ck::bf8_t> ...@@ -223,6 +249,22 @@ struct GeneratorTensor_3<ck::bf8_t>
}; };
#endif #endif
struct GeneratorTensor_3<ck::f4_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f4_t operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::type_convert<ck::f4_t>(fp32_tmp);
}
};
template <typename T> template <typename T>
struct GeneratorTensor_4 struct GeneratorTensor_4
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment