Unverified Commit 1c8126a4 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

add batch_stride into batched gemm (#314)



* add batch_stride

* fixed test
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 0dcb3496
......@@ -32,6 +32,9 @@ struct DeviceBatchedGemm : public BaseOperator
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -341,6 +341,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
......@@ -357,10 +360,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01},
......@@ -543,6 +543,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......@@ -557,6 +560,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1,
1,
a_element_op,
......@@ -577,6 +583,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......@@ -591,6 +600,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1,
1,
a_element_op,
......
......@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
int M,
int N,
int K,
int BatchStrideA,
int BatchStrideB,
int BatchStrideC,
int StrideA,
int StrideB,
int StrideC,
......@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({row * stride, stride, 1}));
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({col * stride, 1, stride}));
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
......@@ -150,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
......
......@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;
const int BatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
const int BatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
const int BatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
bool pass = ck::profiler::
profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
do_verification,
......@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
StrideA_,
StrideB_,
StrideC_,
BatchCount);
return pass ? 0 : 1;
......
......@@ -25,19 +25,19 @@ int main()
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
true, 1, false, 1, M, N, K, K, N, N, BatchCount);
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
true, 1, false, 1, M, N, K, K, K, N, BatchCount);
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
true, 1, false, 1, M, N, K, M, N, N, BatchCount);
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
true, 1, false, 1, M, N, K, M, K, N, BatchCount);
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment