Commit 486e7129 authored by rocking's avatar rocking
Browse files

Refine gemm dlops int8 kernel parameter

parent b8966f7a
...@@ -27,9 +27,9 @@ using I32 = int32_t; ...@@ -27,9 +27,9 @@ using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using ActivationOp = PassThrough; using ActivationOp = PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>; using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>;
...@@ -65,26 +65,26 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Dl< ...@@ -65,26 +65,26 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
16, // K0PerBlock 16, // K0PerBlock
2, // K1 4, // K1
4, // M1PerThread 4, // M1PerThread
4, // N1PerThread 4, // N1PerThread
1, // KPerThread 1, // KPerThread
S<8, 2>, // M1N1ThreadClusterM1Xs S<8, 2>, // M1N1ThreadClusterM1Xs
S<8, 2>, // M1N1ThreadClusterN1Xs S<8, 2>, // M1N1ThreadClusterN1Xs
S<8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
S<4, 1, 1, 2>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 2>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
S<2, 1, 4, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
S<8, 1, 32, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
S<0, 3, 1, 2>, // BBlockTransferSrcAccessOrder S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
S<1, 1, 4, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
S<0, 3, 1, 2>, // BBlockTransferSrcVectorTensorContiguousDimOrder S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 4, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim 5, // CThreadTransferSrcDstVectorDim
4>; // CThreadTransferDstScalarPerVector 4>; // CThreadTransferDstScalarPerVector
...@@ -133,8 +133,8 @@ int main() ...@@ -133,8 +133,8 @@ int main()
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-128, 127}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-128, 127}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
...@@ -143,8 +143,8 @@ int main() ...@@ -143,8 +143,8 @@ int main()
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = PassThrough{}; auto a_element_op = AElementOp{};
auto b_element_op = PassThrough{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
// do GEMM // do GEMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment