Commit 74603261 authored by Chao Liu
Browse files

fix initialization issue

parent 360184cd
......@@ -197,8 +197,8 @@ int run_conv_bwd_data(bool do_verification,
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpace());
......
......@@ -6,8 +6,8 @@
#include "ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp"
using InDataType = ck::bhalf_t;
using WeiDataType =
float; // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
// The bf16 kernel uses fp32 atomic add to accumulate the weight tensor into global memory
using WeiDataType = float;
using OutDataType = ck::bhalf_t;
using AccDataType = float;
......
......@@ -154,8 +154,8 @@ bool profile_conv_bwd_data_impl(int do_verification,
weight.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
weight.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
output.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
......
......@@ -156,12 +156,12 @@ bool profile_conv_bwd_weight_impl(int do_verification,
{
case 0: break;
case 1:
input.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
output.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break;
default:
input.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
output.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
output.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
......
......@@ -197,7 +197,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::BF16_F32_BF16)
{
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{});
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, F32{}, BF16{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK)
......@@ -212,7 +213,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::BF16_F32_BF16)
{
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{});
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, F32{}, BF16{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK)
......@@ -227,7 +229,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::BF16_F32_BF16)
{
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{});
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, F32{}, BF16{});
}
}
......
......@@ -196,6 +196,6 @@ int main()
else
{
std::cout << "test convnd bwd: Fail " << std::endl;
return -1;
return 1;
}
}
add_test_executable(test_convnd_bwd_weight convnd_bwd_weight.cpp)
target_link_libraries(test_convnd_bwd_weight PRIVATE utility device_convnd_bwd_weight_instance)
target_link_libraries(test_convnd_bwd_weight PRIVATE utility device_conv1d_bwd_weight_instance device_conv2d_bwd_weight_instance device_conv3d_bwd_weight_instance)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment