Commit b3a4d179 authored by Jing Zhang
Browse files

fixed output

parent 9bdad55b
......@@ -113,7 +113,8 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(4, 2, 4)), make_pass_through_transform(N)),
make_tuple(make_unmerge_transform(make_tuple(GemmM / 8, 2, 4)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
......
......@@ -476,11 +476,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// output: register to global memory
{
StaticBuffer<AddressSpace::Vgpr, float, 64> c_thread_buf_;
static_for<0, 64, 1>{}(
[&](auto i) { c_thread_buf_(i) = c_thread_buf.template AsType<float>()[i]; });
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t K0 = OutputLayout.M1();
constexpr index_t K1 = OutputLayout.N1();
......@@ -498,8 +493,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
static_assert(BlkSize == 16 && NumBlks == 4, "");
// force unrolling the output loop to get rid of scratches
for(index_t i = 0; i < NumBlks; ++i)
{
static_for<0, NumBlks, 1>{}([&](auto i) {
StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
static_for<0, BlkSize, 1>{}([&](auto j) {
c_thread_buf_(j) =
c_thread_buf.template AsType<float>()[Number<i * BlkSize + j>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
......@@ -535,7 +536,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
c_m0_m1_m2_n_global_desc,
c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
});
}
}
......
......@@ -688,7 +688,7 @@ int main(int argc, char* argv[])
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 1
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment