Commit 74603261 authored by Chao Liu's avatar Chao Liu
Browse files

fix initialization issue

parent 360184cd
...@@ -197,8 +197,8 @@ int run_conv_bwd_data(bool do_verification, ...@@ -197,8 +197,8 @@ int run_conv_bwd_data(bool do_verification,
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break; break;
default: default:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}); out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1}); wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpace());
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
#include "ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp"
using InDataType = ck::bhalf_t; using InDataType = ck::bhalf_t;
using WeiDataType = // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
float; // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory using WeiDataType = float;
using OutDataType = ck::bhalf_t; using OutDataType = ck::bhalf_t;
using AccDataType = float; using AccDataType = float;
......
...@@ -154,8 +154,8 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -154,8 +154,8 @@ bool profile_conv_bwd_data_impl(int do_verification,
weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break; break;
default: default:
output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}); output.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
weight.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1}); weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
......
...@@ -156,12 +156,12 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -156,12 +156,12 @@ bool profile_conv_bwd_weight_impl(int do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
input.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2}); input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
output.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2}); output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break; break;
default: default:
input.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}); input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
output.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1}); output.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
......
...@@ -197,7 +197,8 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -197,7 +197,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{}); // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, F32{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK)
...@@ -212,7 +213,8 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -212,7 +213,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{}); // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, F32{}, BF16{});
} }
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK)
...@@ -227,7 +229,8 @@ int profile_conv_bwd_weight(int argc, char* argv[]) ...@@ -227,7 +229,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
} }
else if(data_type == ConvDataType::BF16_F32_BF16) else if(data_type == ConvDataType::BF16_F32_BF16)
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{}); // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, F32{}, BF16{});
} }
} }
......
...@@ -196,6 +196,6 @@ int main() ...@@ -196,6 +196,6 @@ int main()
else else
{ {
std::cout << "test convnd bwd: Fail " << std::endl; std::cout << "test convnd bwd: Fail " << std::endl;
return -1; return 1;
} }
} }
add_test_executable(test_convnd_bwd_weight convnd_bwd_weight.cpp) add_test_executable(test_convnd_bwd_weight convnd_bwd_weight.cpp)
target_link_libraries(test_convnd_bwd_weight PRIVATE utility device_convnd_bwd_weight_instance) target_link_libraries(test_convnd_bwd_weight PRIVATE utility device_conv1d_bwd_weight_instance device_conv2d_bwd_weight_instance device_conv3d_bwd_weight_instance)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment