"vscode:/vscode.git/clone" did not exist on "cf084d8bd3f2f0b57f12e1c1788567372a09d481"
Commit 0bf754ec authored by ltqin's avatar ltqin
Browse files

modify transformat

parent 59ac6f84
......@@ -62,23 +62,16 @@ transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmN = C * Y * X;
const auto GemmK = N * Ho * Wo;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
......@@ -104,7 +97,7 @@ transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
......@@ -114,15 +107,22 @@ transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(out_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment