"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "20ea6c759e27b8794cc187fe683ca77051b02e51"
Commit 0ade7981 authored by Jing Zhang

add mnk padding

parent ab04f22f
@@ -39,7 +39,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;
 
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl
@@ -47,7 +49,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermu
 //######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+        // < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+        < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on
 
 using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
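
Note: this commit switches the GemmSpecialization from Default to MNKPadding, which lets one compiled instance handle problem sizes that are not multiples of the block tile (the new M = 88 and K = 88 set later in this commit are not multiples of the tiles). Below is a minimal standalone sketch of the round-up arithmetic MNKPadding implies; MPerBlock = 128 and NPerBlock = 64 are read off the new instance above, while treating the K tile as K0PerBlock * K1 = 32 * 8 = 256 is an assumption about how K is tiled, not something this diff states.

// Hedged sketch, not composable_kernel code: each GEMM dimension is
// rounded up to a multiple of its per-block tile, which is what the
// MNKPadding specialization permits the kernel to assume.
#include <cstdio>

constexpr int ceil_to_multiple(int x, int tile) { return (x + tile - 1) / tile * tile; }

int main()
{
    std::printf("M: 88 -> %d\n", ceil_to_multiple(88, 128)); // 128 (MPerBlock = 128)
    std::printf("N: 64 -> %d\n", ceil_to_multiple(64, 64));  // 64  (already aligned)
    std::printf("K: 88 -> %d\n", ceil_to_multiple(88, 256)); // 256 (assumed K tile)
    return 0;
}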
@@ -59,13 +62,9 @@ int main(int argc, char* argv[])
     int init_method  = 1;
     bool time_kernel = false;
 
-    // const int M = 88;
-    // const int N = 64;
-    // const int K = 88;
-
-    const int M = 256;
-    const int N = 128;
-    const int K = 64;
+    const int M = 88;
+    const int N = 64;
+    const int K = 88;
 
     const int stride_A = K;
     const int stride_B = K;
@@ -76,8 +75,8 @@ int main(int argc, char* argv[])
     const int batch_count = G0 * G1;
 
     // output layout - [G0, M, G1, N]
-    const int stride_B0 = M * G1 * N;
-    const int stride_B1 = N;
+    const int stride_G0 = M * G1 * N;
+    const int stride_G1 = N;
     const int stride_M = G1 * N;
     const int stride_N = 1;
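
Note: the rename from stride_B0/stride_B1 to stride_G0/stride_G1 matches the [G0, M, G1, N] output layout named in the comment: g1 varies faster than m in memory, which is the point of the C permutation. A minimal sketch of the linear offset these strides imply; linear_offset is a hypothetical helper written for illustration, not a composable_kernel API, and G1 = 2 in the usage line is an arbitrary value (the example's G0/G1 are set outside this excerpt).

#include <cstddef>
#include <cstdio>

// Hypothetical helper: linear offset of element (g0, g1, m, n) in the
// permuted output layout [G0, M, G1, N], using the strides defined above.
std::size_t linear_offset(std::size_t g0, std::size_t g1, std::size_t m, std::size_t n,
                          std::size_t G1, std::size_t M, std::size_t N)
{
    const std::size_t stride_G0 = M * G1 * N; // one [M, G1, N] slab per g0
    const std::size_t stride_M  = G1 * N;     // m strides over a [G1, N] slab
    const std::size_t stride_G1 = N;          // g1 sits between m and n
    const std::size_t stride_N  = 1;          // n is contiguous
    return g0 * stride_G0 + m * stride_M + g1 * stride_G1 + n * stride_N;
}

int main()
{
    // with G1 = 2 (assumed), M = 88, N = 64: element (0, 1, 2, 3)
    std::printf("%zu\n", linear_offset(0, 1, 2, 3, 2, 88, 64)); // 64 + 2*128 + 3 = 323
    return 0;
}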
@@ -97,7 +96,7 @@ int main(int argc, char* argv[])
 
     // GEMM shape
     ck::tensor_operation::device::BatchedGemmCPermuteDesc batched_gemm_c_permute_desc{
-        G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N};
+        G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
 
     auto f_host_tensor_descriptor = [](std::size_t batch_count_,
                                        std::size_t row,
@@ -119,24 +118,24 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
     Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
 
-    auto f_host_c_tensor_descriptor = [](std::size_t B0_,
-                                         std::size_t B1_,
+    auto f_host_c_tensor_descriptor = [](std::size_t G0_,
+                                         std::size_t G1_,
                                          std::size_t M_,
                                          std::size_t N_,
-                                         std::size_t stride_B0_,
-                                         std::size_t stride_B1_,
+                                         std::size_t stride_G0_,
+                                         std::size_t stride_G1_,
                                          std::size_t stride_M_,
                                          std::size_t stride_N_) {
         return HostTensorDescriptor(
-            std::vector<std::size_t>({B0_, B1_, M_, N_}),
-            std::vector<std::size_t>({stride_B0_, stride_B1_, stride_M_, stride_N_}));
+            std::vector<std::size_t>({G0_, G1_, M_, N_}),
+            std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
     };
 
     Tensor<CDataType> c_g0_g1_m_n_host_result(
-        f_host_c_tensor_descriptor(G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N));
+        f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
     Tensor<CDataType> c_g0_g1_m_n_device_result(
-        f_host_c_tensor_descriptor(G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N));
+        f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
 
     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
...
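
Note: the host result tensors are indexed as (g0, g1, m, n), but the strides passed to f_host_c_tensor_descriptor keep the physical [G0, M, G1, N] order, so the descriptor describes a permuted view rather than a plain row-major one. A standalone sketch of that distinction using this commit's stride formulas; no composable_kernel types are used, and G1 = 2 is an arbitrary value chosen for the check (the example's G0/G1 are set outside this excerpt).

#include <array>
#include <cstddef>
#include <cstdio>

int main()
{
    // dims are ordered (G0, G1, M, N); strides {M*G1*N, N, G1*N, 1} are
    // not row-major for that order: g1 moves by N while m moves by G1*N.
    const std::size_t G1 = 2, M = 88, N = 64;
    const std::array<std::size_t, 4> stride = {M * G1 * N, N, G1 * N, 1};

    // n is fastest (stride 1); incrementing g1 moves by N, but incrementing
    // m moves by G1*N -- so memory order is [G0, M, G1, N] even though the
    // tensor is indexed as (g0, g1, m, n).
    std::printf("step n:  %zu\n", stride[3]); // 1
    std::printf("step g1: %zu\n", stride[1]); // 64  (= N)
    std::printf("step m:  %zu\n", stride[2]); // 128 (= G1*N)
    return 0;
}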