Commit d8a632a8 authored by Jing Zhang

fixed bfloat16 issues

parent 89e1ebd4
@@ -927,7 +927,7 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
 using int8x32_t = typename vector_type<int8_t, 32>::type;
 using int8x64_t = typename vector_type<int8_t, 64>::type;

-__host__ __device__ float bf16_to_f32(ushort src_val)
+static __host__ __device__ float bf16_to_f32(ushort src_val)
 {
     union
     {
@@ -937,7 +937,7 @@ __host__ __device__ float bf16_to_f32(ushort src_val)
     return u.fp32;
 }

-__host__ __device__ ushort f32_to_bf16(float src_val)
+static __host__ __device__ ushort f32_to_bf16(float src_val)
 {
     union
     {
......
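For context, these two helpers treat a bf16 value as the upper 16 bits of an IEEE-754 float. A minimal standalone sketch of that round trip follows; the helper names and the main() are illustrative, and since the body of f32_to_bf16 is elided in this hunk, the truncating conversion shown below is an assumption, not necessarily what the patch implements.

#include <cstdint>
#include <cstdio>

// bf16 -> f32: place the 16 stored bits into the top half of a float.
static float bf16_to_f32_sketch(uint16_t src_val)
{
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(src_val) << 16};
    return u.fp32;
}

// f32 -> bf16: keep only the top 16 bits (plain truncation; the real
// f32_to_bf16 may round instead - its body is not shown in the diff).
static uint16_t f32_to_bf16_sketch(float src_val)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {src_val};
    return uint16_t(u.int32 >> 16);
}

int main()
{
    float x = 3.14159f;
    uint16_t b = f32_to_bf16_sketch(x);
    // The round trip keeps the sign, the 8-bit exponent and 7 mantissa bits.
    std::printf("%f -> 0x%04x -> %f\n", x, b, bf16_to_f32_sketch(b));
}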
#pragma once
#include "host_tensor.hpp"
template <>
void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
const Tensor<ushort>& b,
Tensor<ushort>& c,
const GemmMatrixLayout layout)
{
if(layout == GemmMatrixLayout::MK_KN_MN)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
}
c(m, n) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_MN)
{
auto f_mk_nk_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
}
c(m, n) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_MN)
{
auto f_km_kn_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
}
c(m, n) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_MN)
{
auto f_km_nk_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
}
c(m, n) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_KN_NM)
{
auto f_mk_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
}
c(n, m) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_NM)
{
auto f_mk_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
}
c(n, m) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_NM)
{
auto f_km_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
}
c(n, m) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_NM)
{
auto f_km_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
}
c(n, m) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
}
}
template <typename AType, typename BType, typename CType>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
......
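Every layout branch of the ushort specialization follows the same pattern: decode both bf16 operands to f32, accumulate the dot product in a wider type, and convert back to bf16 only once when storing C. A self-contained sketch of the MK_KN_MN case with plain arrays; the function names, the truncating f32-to-bf16 conversion, and the sizes in main() are illustrative only, not part of the patch.

#include <cstdint>
#include <cstdio>
#include <vector>

// Assumed conversions, mirroring ck::bf16_to_f32 / ck::f32_to_bf16.
static float bf16_to_f32_sketch(uint16_t v)
{
    union { uint32_t i; float f; } u = {uint32_t(v) << 16};
    return u.f;
}
static uint16_t f32_to_bf16_sketch(float v)
{
    union { float f; uint32_t i; } u = {v};
    return uint16_t(u.i >> 16);
}

// Reference GEMM for row-major A (MxK), row-major B (KxN), row-major C (MxN):
// accumulate in double, round to bf16 only when writing C.
static void host_gemm_bf16_sketch(const std::vector<uint16_t>& a,
                                  const std::vector<uint16_t>& b,
                                  std::vector<uint16_t>& c,
                                  int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            double v = 0;
            for(int k = 0; k < K; ++k)
                v += double(bf16_to_f32_sketch(a[m * K + k])) *
                     double(bf16_to_f32_sketch(b[k * N + n]));
            c[m * N + n] = f32_to_bf16_sketch(float(v));
        }
}

int main()
{
    const int M = 2, N = 2, K = 4;
    std::vector<uint16_t> a(M * K, f32_to_bf16_sketch(1.0f));
    std::vector<uint16_t> b(K * N, f32_to_bf16_sketch(0.5f));
    std::vector<uint16_t> c(M * N);
    host_gemm_bf16_sketch(a, b, c, M, N, K);
    // Each C entry is 4 * 1.0 * 0.5 = 2.0, i.e. bf16 bit pattern 0x4000.
    std::printf("c[0] = %f\n", bf16_to_f32_sketch(c[0]));
}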
@@ -299,53 +299,41 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> s
 void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);

+float bf16_to_f32_(ushort src_val);
+
 template <typename T>
 void check_error(const Tensor<T>& ref, const Tensor<T>& result)
 {
     float error = 0;
     float max_diff = -1;
     float ref_value = 0, result_value = 0;
-    for(int i = 0; i < ref.mData.size(); ++i)
-    {
-        error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
-        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
-        if(max_diff < diff)
-        {
-            max_diff = diff;
-            ref_value = ref.mData[i];
-            result_value = result.mData[i];
-        }
-    }
-
-    std::cout << "error: " << error << std::endl;
-    std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
-}
-
-__host__ __device__ float bf16_to_f32(ushort src_val)
-{
-    union
-    {
-        uint32_t int32;
-        float fp32;
-    } u = {uint32_t(src_val) << 16};
-    return u.fp32;
-}
-
-template <>
-void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result)
-{
-    float error = 0;
-    float max_diff = -1;
-    float ref_value = 0, result_value = 0;
-    for(int i = 0; i < ref.mData.size(); ++i)
-    {
-        error += std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
-        float diff = std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
-        if(max_diff < diff)
-        {
-            max_diff = diff;
-            ref_value = bf16_to_f32(ref.mData[i]);
-            result_value = bf16_to_f32(result.mData[i]);
-        }
-    }
+    if constexpr(std::is_same<ushort, T>::value)
+    {
+        for(int i = 0; i < ref.mData.size(); ++i)
+        {
+            error += std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
+            float diff = std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
+            if(max_diff < diff)
+            {
+                max_diff = diff;
+                ref_value = bf16_to_f32_(ref.mData[i]);
+                result_value = bf16_to_f32_(result.mData[i]);
+            }
+        }
+    }
+    else
+    {
+        for(int i = 0; i < ref.mData.size(); ++i)
+        {
+            error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
+            float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
+            if(max_diff < diff)
+            {
+                max_diff = diff;
+                ref_value = ref.mData[i];
+                result_value = result.mData[i];
+            }
+        }
+    }
......
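The reason the ushort path must go through bf16_to_f32_ rather than the generic double() cast is that, for bf16 data, the ushort holds a bit pattern rather than a magnitude, so differencing the raw values measures nothing meaningful. A small illustration; the helper name is hypothetical and mirrors the declaration above.

#include <cstdint>
#include <cstdio>

static float bf16_to_f32_sketch(uint16_t v)
{
    union { uint32_t i; float f; } u = {uint32_t(v) << 16};
    return u.f;
}

int main()
{
    uint16_t one          = 0x3F80; // bf16 bit pattern of 1.0f
    uint16_t one_and_half = 0x3FC0; // bf16 bit pattern of 1.5f
    // Raw integer difference of the bit patterns: 64 (meaningless as an error).
    std::printf("raw diff   = %d\n", int(one_and_half) - int(one));
    // Difference of the decoded values: 0.5 (the actual numerical error).
    std::printf("value diff = %f\n", bf16_to_f32_sketch(one_and_half) - bf16_to_f32_sketch(one));
}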
@@ -61,3 +61,13 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream
    LogRange(os, desc.GetStrides(), ", ");
    os << "}" << std::endl;
}
float bf16_to_f32_(ushort src_val)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(src_val) << 16};
return u.fp32;
}
@@ -106,12 +106,12 @@ void profile_conv(int do_verification,
     {
     case 0: break;
     case 1:
-        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
         break;
     default:
-        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
-        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
     }

     if(do_verification)
......
@@ -122,12 +122,12 @@ void profile_gemm(int do_verification,
     {
     case 0: break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
-        b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
     }

     if(do_verification)
......