Commit f2948084 authored by mtgu0705's avatar mtgu0705
Browse files

Added two kernel for M=32 problem

parent 2944c508
...@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>; ...@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t; using F16 = ck::half_t;
using FP8 = ck::f8_t; using FP8 = ck::f8_t;
using F32 = float; using F32 = float;
using BF16 = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -38,7 +39,7 @@ using CShuffleDataType = F32; ...@@ -38,7 +39,7 @@ using CShuffleDataType = F32;
using D0DataType = F32; using D0DataType = F32;
using D1DataType = F32; using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>; using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16; using EDataType = BF16;
using A0Layout = Row; using A0Layout = Row;
using B0Layout = Col; using B0Layout = Col;
...@@ -47,21 +48,23 @@ using D1Layout = Col; ...@@ -47,21 +48,23 @@ using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>; using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row; using ELayout = Row;
struct MultiplyMultiply // struct MultiplyMultiply
{ // {
template <typename E, typename C, typename D0, typename D1> // template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void // __host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const; // operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <> // template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>( // __host__ __device__ constexpr void operator()<ck::bhalf_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const // ck::half_t& e, const float& c, const float& d0, const float& d1) const
{ // {
const float x0_f = c * d0 * d1; // const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::half_t>(x0_f); // e = ck::type_convert<ck::bhalf_t>(x0_f);
} // }
}; // };
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...@@ -80,7 +83,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -80,7 +83,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
///###### RRR ///###### RRR
///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
///###### RCR ///###### RCR
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // kernel 1: 256->32x128x128
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// kernel 2: 128->32x128x128
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// clang-format on // clang-format on
int main(int argc, char* argv[]) int main(int argc, char* argv[])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment