"examples/vscode:/vscode.git/clone" did not exist on "4915524fa189651a1ab08b44690cc0cb8b772282"
Unverified Commit b51808d7 authored by ltqin's avatar ltqin Committed by GitHub
Browse files

Fix conv2d bwd data bug when filter is 1x1 and stride = 2 (#132)



* fix bwd data filter1strid2 bug

* fichangeshort to ck::bhalf_t

* reset input to zero
Co-authored-by: default avatarltqin <letaoqin@amd.com>
parent 9a17e7fb
......@@ -180,6 +180,10 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// reset input to zero
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
// do GEMM
auto conv = DeviceConvBwdDataInstance{};
auto invoker = conv.MakeInvoker();
......
......@@ -459,6 +459,16 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{
for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
{
// check slice is valid
const index_t Y = filter_spatial_lengths_[0];
const index_t X = filter_spatial_lengths_[1];
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
if(YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
N,
K,
......
......@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_data_instance {
using BF16 = ushort;
using BF16 = ck::bhalf_t;
using F32 = float;
template <ck::index_t... Is>
......
......@@ -11,7 +11,7 @@
using F16 = ck::half_t;
using F32 = float;
using BF16 = ushort;
using BF16 = ck::bhalf_t;
using INT8 = int8_t;
namespace ck {
namespace tensor_operation {
......@@ -172,9 +172,9 @@ void profile_conv_bwd_data_impl(int do_verification,
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>)
{
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
......
......@@ -182,8 +182,8 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{5});
// reset input to zero
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
// get host result
......@@ -225,9 +225,9 @@ int main(int argc, char* argv[])
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ushort> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ushort> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ushort>)
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>)
{
ck::tensor_operation::device::device_conv2d_bwd_data_instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment