Commit b3a4d179 authored by Jing Zhang
Browse files

fixed output

parent 9bdad55b
...@@ -113,7 +113,8 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -113,7 +113,8 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor( const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc, out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(4, 2, 4)), make_pass_through_transform(N)), make_tuple(make_unmerge_transform(make_tuple(GemmM / 8, 2, 4)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
......
...@@ -476,11 +476,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -476,11 +476,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// output: register to global memory // output: register to global memory
{ {
StaticBuffer<AddressSpace::Vgpr, float, 64> c_thread_buf_;
static_for<0, 64, 1>{}(
[&](auto i) { c_thread_buf_(i) = c_thread_buf.template AsType<float>()[i]; });
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout(); constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t K0 = OutputLayout.M1(); constexpr index_t K0 = OutputLayout.M1();
constexpr index_t K1 = OutputLayout.N1(); constexpr index_t K1 = OutputLayout.N1();
...@@ -498,8 +493,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -498,8 +493,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
static_assert(BlkSize == 16 && NumBlks == 4, ""); static_assert(BlkSize == 16 && NumBlks == 4, "");
// force unrolling the output loop to get rid of scratches // force unrolling the output loop to get rid of scratches
for(index_t i = 0; i < NumBlks; ++i) static_for<0, NumBlks, 1>{}([&](auto i) {
{ StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
static_for<0, BlkSize, 1>{}([&](auto j) {
c_thread_buf_(j) =
c_thread_buf.template AsType<float>()[Number<i * BlkSize + j>{}];
});
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
...@@ -535,7 +536,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -535,7 +536,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
c_m0_m1_m2_n_global_desc, c_m0_m1_m2_n_global_desc,
c_global_buf, c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks); c_m0_m1_n0_n1_global_tensor_iterator_hacks);
} });
} }
} }
......
...@@ -688,7 +688,7 @@ int main(int argc, char* argv[]) ...@@ -688,7 +688,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 1 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment