"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "dc86bd421ed98777112c64f61940321631c11806"
Unverified Commit a651ea4f authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Fixed bfp16 host_conv_fwd (#52)



* fixed bfloat16 issues

* refactor type_convert

* fixed host_convolution_forward for ushort
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 0a66c54e
...@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in, ...@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
if constexpr(is_same<TOut, ushort>::value) if constexpr(is_same<TOut, ushort>::value)
{ {
out(n, k, ho, wo) = type_convert<ushort>(v); out(n, k, ho, wo) = ck::type_convert<ushort>(static_cast<float>(v));
} }
else else
{ {
...@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in, ...@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
} }
if constexpr(is_same<TOut, ushort>::value) if constexpr(is_same<TOut, ushort>::value)
{ {
out(n, ho, wo, k) = ck::type_convert<ushort>(v); out(n, ho, wo, k) = ck::type_convert<ushort>(static_cast<float>(v));
} }
else else
{ {
...@@ -257,7 +257,7 @@ int main(int argc, char* argv[]) ...@@ -257,7 +257,7 @@ int main(int argc, char* argv[])
using in_data_t = float; using in_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 0 #elif 1
using in_data_t = half_t; using in_data_t = half_t;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = half_t; using out_data_t = half_t;
......
...@@ -239,14 +239,10 @@ int main(int argc, char* argv[]) ...@@ -239,14 +239,10 @@ int main(int argc, char* argv[])
using ab_data_t = float; using ab_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using c_data_t = float; using c_data_t = float;
#elif 0 #elif 1
using ab_data_t = half_t; using ab_data_t = half_t;
using acc_data_t = float; using acc_data_t = float;
using c_data_t = half_t; using c_data_t = half_t;
#elif 1
using ab_data_t = ushort;
using acc_data_t = float;
using c_data_t = ushort;
#elif 1 #elif 1
using ab_data_t = int8_t; using ab_data_t = int8_t;
using acc_data_t = int32_t; using acc_data_t = int32_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment