Commit 3b22355a authored by letaoqin
Browse files

add a tuning instance

parent 12c3c3c0
......@@ -36,7 +36,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
// clang-format off
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType, typename CDEElementOp>
using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
using DeviceOpInstance_128_32_64_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType,
DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
......@@ -52,6 +52,24 @@ using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMul
1, 1,
S<1, 16, 1, 8>, S<8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, F16>;
// Alias template naming one fixed configuration of
// ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3.
//
// Only the A/B/Ds/C data types and the CDE element-wise operation are left as
// template parameters; the layouts, AccDataType, CShuffleDataType, AElementOp,
// BElementOp and GemmSpec are taken from names declared elsewhere in this file.
//
// The "_256_128_128_64" suffix mirrors the first four numeric arguments below
// (256, then 128, 128, 64).
// NOTE(review): the meaning of each positional argument is defined by the
// DeviceGemmMultiD_Xdl_CShuffle_V3 parameter list, which is not visible here;
// the per-line notes below are assumptions to confirm against that header.
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType, typename CDEElementOp>
using DeviceOpInstance_256_128_128_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType,
DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
// presumably the thread-block size — TODO confirm
256,
// presumably the M, N and K tile sizes per block — TODO confirm
128, 128, 64,
8, 4,
32, 32,
2, 2,
// presumably the A-side block-transfer settings — TODO confirm
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
// presumably the B-side block-transfer settings — TODO confirm
S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 4, 0,
1, 1,
S<1, 32, 1, 8>, S<8, 8>,
// This instance selects the Intrawave scheduler with pipeline version v3.
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F16>;
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType, typename CDEElementOp>
using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType,
......@@ -128,14 +146,32 @@ float run_impl(const GemmBiasAddArgs& args, const StreamConfig& config)
return true;
};
auto gemm =
DeviceOpInstance_64_16_16_64<ADataType, BDataType, DsDataType, CDataType, CDEElementOp>{};
if(!Run(gemm))
do
{
if(args.M <= 512)
{
auto gemm = DeviceOpInstance_128_32_64_64<ADataType,
BDataType,
DsDataType,
CDataType,
CDEElementOp>{};
if(Run(gemm))
break;
}
else
{
auto gemm = DeviceOpInstance_256_128_128_64<ADataType,
BDataType,
DsDataType,
CDataType,
CDEElementOp>{};
if(Run(gemm))
break;
}
auto gemm_def =
DeviceOpInstance_default<ADataType, BDataType, DsDataType, CDataType, CDEElementOp>{};
Run(gemm_def);
}
} while(0);
return ave_time;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment