Commit 2807c69e authored by Jing Zhang
Browse files

fixed tensor init

parent 05ab9105
...@@ -139,7 +139,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -139,7 +139,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{ {
case 0: case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{0x99}); b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break; break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
...@@ -158,9 +158,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -158,9 +158,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
} }
b_k_n(0, 0) = 0xaa;
b_k_n(1, 1) = 0xaa;
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......
...@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t> ...@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
} }
}; };
/// Constant generator specialised for ck::pk_i4_t: every generated element is
/// one byte that carries the same 4-bit field in its high and low halves.
template <>
struct GeneratorTensor_1<ck::pk_i4_t>
{
    // Value written (after a +8 bias) into both 4-bit fields of every element.
    int8_t value = 1;

    // The index arguments are accepted but never read, so the result is the
    // same for every position.
    template <typename... Is>
    ck::pk_i4_t operator()(Is...)
    {
        // Bias by 8 so that values in [-8, 7] land in [0, 15].
        // NOTE(review): presumably the offset encoding pk_i4_t consumers expect - confirm.
        const int field = value + 8;

        // Same field in the upper and the lower half; keep only the low byte.
        const int upper = field << 4;
        ck::pk_i4_t packed = (upper + field) & 0xff;
        return packed;
    }
};
template <typename T> template <typename T>
struct GeneratorTensor_2 struct GeneratorTensor_2
{ {
...@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t> ...@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
} }
}; };
/// Random generator specialised for ck::pk_i4_t: every generated element is one
/// byte holding two independently drawn 4-bit fields (high field drawn first).
template <>
struct GeneratorTensor_2<ck::pk_i4_t>
{
    // Inclusive lower bound of each drawn value (before the +8 bias).
    int min_value = 0;
    // Exclusive upper bound of each drawn value (before the +8 bias).
    int max_value = 1;

    // The index arguments are accepted but never read.
    template <typename... Is>
    ck::pk_i4_t operator()(Is...)
    {
        const int span = max_value - min_value;

        // `x % 0` is undefined behaviour, so an empty range
        // (min_value == max_value) yields min_value without calling std::rand().
        // Every other range behaves exactly as before.
        const auto draw = [&]() {
            const int offset = (span != 0) ? std::rand() % span : 0;
            // Bias by 8 so that values in [-8, 7] land in [0, 15].
            return offset + min_value + 8;
        };

        // Two separate statements keep the draw order fixed (high, then low) so
        // seeded runs stay reproducible.
        const int hi = draw();
        const int lo = draw();

        // NOTE: a biased value outside [0, 15] does not fit its 4-bit field; the
        // low field then carries into the high one and `& 0xff` drops the rest.
        // Callers are expected to keep the bounds within [-8, 8).
        ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
        return r;
    }
};
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
template <> template <>
struct GeneratorTensor_2<ck::f8_t> struct GeneratorTensor_2<ck::f8_t>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment