"vscode:/vscode.git/clone" did not exist on "8f6fc266cde284431dfdf5bc1b1de212356eb31a"
Commit 7b01dbee authored by ltqin's avatar ltqin
Browse files

change disired grid size to parameters

parent cca0ceee
......@@ -28,14 +28,14 @@ using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple<
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>
// clang-format on
>;
......
......@@ -28,14 +28,14 @@ using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple<
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 1, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 1, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>
// clang-format on
>;
......
......@@ -28,15 +28,15 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 1, 3, 4>, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 8>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 1, 3, 4>, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 8>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true>
>;
template <>
......
......@@ -28,19 +28,19 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true>
// clang-format on
>;
......
......@@ -25,7 +25,8 @@ struct DeviceGemm : public BaseOperator
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
CElementwiseOperation c_element_op,
ck::index_t desired_gride_size = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -56,8 +56,7 @@ template <typename ADataType,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
bool ABlockLdsAddExtraM,
bool BBlockLdsAddExtraN,
ck::index_t DesiredGridSize>
bool BBlockLdsAddExtraN>
struct DeviceGemmSplitKXdl
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
......@@ -147,7 +146,7 @@ struct DeviceGemmSplitKXdl
}
}
static auto GetKBatchAndKPad(index_t M, index_t N, index_t K)
static auto GetKBatchAndKPad(index_t M, index_t N, index_t K, index_t DesiredGridSize)
{
const auto GridMN = M * N / (MPerBlock * NPerBlock);
const index_t KBatch = std::max(DesiredGridSize / GridMN, 1);
......@@ -264,7 +263,8 @@ struct DeviceGemmSplitKXdl
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
index_t desired_grid_Size)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -277,10 +277,12 @@ struct DeviceGemmSplitKXdl
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
desired_grid_size_{desired_grid_Size}
{
int KBatch = 1, KPad = K;
std::tie(KBatch, KPad) = DeviceGemmSplitKXdl::GetKBatchAndKPad(M, N, K);
std::tie(KBatch, KPad) =
DeviceGemmSplitKXdl::GetKBatchAndKPad(M, N, K, desired_grid_size_);
a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmSplitKXdl::MakeAGridDescriptor_KBatch_K0_M_K1(
M, K, StrideA, KBatch, KPad);
......@@ -316,6 +318,7 @@ struct DeviceGemmSplitKXdl
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t desired_grid_size_;
};
// Invoker
......@@ -524,7 +527,8 @@ struct DeviceGemmSplitKXdl
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
index_t desired_grid_Size)
{
return Argument{p_a,
p_b,
......@@ -539,7 +543,8 @@ struct DeviceGemmSplitKXdl
1,
a_element_op,
b_element_op,
c_element_op};
c_element_op,
desired_grid_Size};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -556,7 +561,8 @@ struct DeviceGemmSplitKXdl
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
CElementwiseOperation c_element_op,
ck::index_t desired_gride_size = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -571,7 +577,8 @@ struct DeviceGemmSplitKXdl
1,
a_element_op,
b_element_op,
c_element_op);
c_element_op,
desired_gride_size);
}
// polymorphic
......
......@@ -408,7 +408,8 @@ struct DeviceGemmXdl
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
CElementwiseOperation c_element_op,
ck::index_t) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......
......@@ -11,7 +11,8 @@ void add_device_splitk_gemm_instance<float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
......@@ -19,7 +20,8 @@ void add_device_splitk_gemm_instance<float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
......@@ -27,7 +29,8 @@ void add_device_splitk_gemm_instance<float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_splitk_gemm_instance<float,
......@@ -35,8 +38,8 @@ void add_device_splitk_gemm_instance<float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
ck::tensor_layout::gemm::RowMajor>(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
......
......@@ -6,10 +6,6 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
template <>
void add_device_gemm_instance<float,
float,
......
......@@ -27,7 +27,8 @@ using DeviceGemmNoOpPtr =
ck::tensor_operation::element_wise::PassThrough>;
using GEMM_PTR = std::vector<DeviceGemmNoOpPtr>;
static std::vector<std::vector<bool>>& GetLayoutType(){
static std::vector<std::vector<bool>>& GetLayoutType()
{
static std::vector<std::vector<bool>> LayOut = {{0, 0, 0}, {0, 1, 0}, {1, 0, 0}, {1, 1, 0}};
return LayOut;
}
......@@ -74,14 +75,14 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
static auto& GetAddDeviceGemmInstance()
{
static std::vector<void (*)(GEMM_PTR&)> AddDeviceGemmInstance = {add_device_gemm_instance_mk_kn_mn,
static std::vector<void (*)(GEMM_PTR&)> AddDeviceGemmInstance = {
add_device_gemm_instance_mk_kn_mn,
add_device_gemm_instance_mk_nk_mn,
add_device_gemm_instance_km_kn_mn,
add_device_gemm_instance_km_nk_mn};
return AddDeviceGemmInstance;
}
static void add_device_gemm_instance(GEMM_PTR& gemm_ptrs, int layout)
{
GetAddDeviceGemmInstance()[layout](gemm_ptrs);
......@@ -105,13 +106,13 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
}
int main(int argc, char* argv[])
{
if(argc != 8)
if(argc != 9)
{
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, n] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, n] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC DesiredGridSize\n");
return 1;
}
......@@ -124,6 +125,7 @@ int main(int argc, char* argv[])
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
const int DesiredGridSize = std::stoi(argv[8]);
if(layout > 3 || layout < 0)
{
......@@ -191,7 +193,8 @@ int main(int argc, char* argv[])
StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
ck::tensor_operation::element_wise::PassThrough{},
DesiredGridSize);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment