Commit 244e313f authored by letaoqin's avatar letaoqin
Browse files

tuning

parent 0435c336
...@@ -40,17 +40,17 @@ using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMul ...@@ -40,17 +40,17 @@ using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMul
ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType,
DsDataType, CDataType, AccDataType, CShuffleDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
64, 128,
16, 16, 64, 32, 64, 64,
8, 8, 8, 4,
16, 16, 32, 32,
1, 1,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 2, 2, 0,
1, 1, 1, 1,
S<1, 16, 1, 4>, S<4, 4>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>; S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 4, 0,
1, 1,
S<1, 16, 1, 8>, S<8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, F16>;
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType, typename CDEElementOp> template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType, typename CDEElementOp>
using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
......
...@@ -240,13 +240,14 @@ int main(int argc, char* argv[]) ...@@ -240,13 +240,14 @@ int main(int argc, char* argv[])
float ave_time = 0; float ave_time = 0;
if(op_type == 0) if(op_type == 0)
gemm_bias_add_gelu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50}); ave_time = gemm_bias_add_gelu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50});
else if(op_type == 1) else if(op_type == 1)
gemm_bias_add_relu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50}); ave_time = gemm_bias_add_relu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50});
else if(op_type == 2) else if(op_type == 2)
gemm_bias_add_silu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50}); ave_time = gemm_bias_add_silu_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50});
else else
gemm_bias_add_sigmoid_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50}); ave_time =
gemm_bias_add_sigmoid_fp16(gemm_args, StreamConfig{nullptr, time_kernel, 20, 50});
// float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); // float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment