Commit d8a632a8 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed bfloat16 issues

parent 89e1ebd4
......@@ -927,7 +927,7 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
__host__ __device__ float bf16_to_f32(ushort src_val)
static __host__ __device__ float bf16_to_f32(ushort src_val)
{
union
{
......@@ -937,7 +937,7 @@ __host__ __device__ float bf16_to_f32(ushort src_val)
return u.fp32;
}
__host__ __device__ ushort f32_to_bf16(float src_val)
static __host__ __device__ ushort f32_to_bf16(float src_val)
{
union
{
......
#pragma once
#include "host_tensor.hpp"
template <>
void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
const Tensor<ushort>& b,
Tensor<ushort>& c,
const GemmMatrixLayout layout)
{
if(layout == GemmMatrixLayout::MK_KN_MN)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
}
c(m, n) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_MN)
{
auto f_mk_nk_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
}
c(m, n) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_MN)
{
auto f_km_kn_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
}
c(m, n) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_MN)
{
auto f_km_nk_mn = [&](auto m, auto n) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
}
c(m, n) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_KN_NM)
{
auto f_mk_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
}
c(n, m) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::MK_NK_NM)
{
auto f_mk_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[1];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
}
c(n, m) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_KN_NM)
{
auto f_km_kn_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
}
c(n, m) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else if(layout == GemmMatrixLayout::KM_NK_NM)
{
auto f_km_nk_nm = [&](auto n, auto m) {
const int K = a.mDesc.GetLengths()[0];
double v = 0;
for(int k = 0; k < K; ++k)
{
v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
}
c(n, m) = ck::f32_to_bf16(v);
};
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
}
}
template <typename AType, typename BType, typename CType>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
const Tensor<BType>& b_k_n,
......
......@@ -299,53 +299,41 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> s
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
float bf16_to_f32_(ushort src_val);
template <typename T>
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
float error = 0;
float max_diff = -1;
float ref_value = 0, result_value = 0;
if constexpr(std::is_same<ushort, T>::value)
{
for(int i = 0; i < ref.mData.size(); ++i)
{
error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
error += std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
float diff = std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i]));
if(max_diff < diff)
{
max_diff = diff;
ref_value = ref.mData[i];
result_value = result.mData[i];
ref_value = bf16_to_f32_(ref.mData[i]);
result_value = bf16_to_f32_(result.mData[i]);
}
}
std::cout << "error: " << error << std::endl;
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
}
__host__ __device__ float bf16_to_f32(ushort src_val)
{
union
}
else
{
uint32_t int32;
float fp32;
} u = {uint32_t(src_val) << 16};
return u.fp32;
}
template <>
void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result)
{
float error = 0;
float max_diff = -1;
float ref_value = 0, result_value = 0;
for(int i = 0; i < ref.mData.size(); ++i)
{
error += std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
float diff = std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
if(max_diff < diff)
{
max_diff = diff;
ref_value = bf16_to_f32(ref.mData[i]);
result_value = bf16_to_f32(result.mData[i]);
ref_value = ref.mData[i];
result_value = result.mData[i];
}
}
}
......
......@@ -61,3 +61,13 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream
LogRange(os, desc.GetStrides(), ", ");
os << "}" << std::endl;
}
float bf16_to_f32_(ushort src_val)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(src_val) << 16};
return u.fp32;
}
......@@ -106,12 +106,12 @@ void profile_conv(int do_verification,
{
case 0: break;
case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
if(do_verification)
......
......@@ -122,12 +122,12 @@ void profile_gemm(int do_verification,
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
if(do_verification)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment