Commit 2d194c52 authored by ltqin
Browse files

input hack

parent 252d271c
...@@ -51,7 +51,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -51,7 +51,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
...@@ -59,10 +59,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -59,10 +59,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
...@@ -98,12 +98,12 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -98,12 +98,12 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
constexpr auto wei_m0_m1_m2_n_grid_step_hacks = constexpr auto wei_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
...@@ -127,7 +127,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -127,7 +127,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
Sequence<0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
...@@ -158,9 +158,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -158,9 +158,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
1, Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmN, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1, GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
......
...@@ -201,8 +201,8 @@ int main(int argc, char* argv[]) ...@@ -201,8 +201,8 @@ int main(int argc, char* argv[])
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break; break;
case 5: case 5:
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread); in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 0.01}, num_thread);
out.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread); out.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 0.01}, num_thread);
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment