Commit 1c8126a4 (unverified), authored by zjing14, committed by GitHub
Browse files

add batch_stride into batched gemm (#314)



* add batch_stride

* fixed test
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 0dcb3496
...@@ -32,6 +32,9 @@ struct DeviceBatchedGemm : public BaseOperator ...@@ -32,6 +32,9 @@ struct DeviceBatchedGemm : public BaseOperator
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
......
...@@ -341,6 +341,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -341,6 +341,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t M01, index_t M01,
index_t N01, index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -357,10 +360,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -357,10 +360,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{ compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
block_2_ctile_map_{ block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01}, M01_{M01},
...@@ -543,6 +543,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -543,6 +543,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -557,6 +560,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -557,6 +560,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1, 1,
1, 1,
a_element_op, a_element_op,
...@@ -577,6 +583,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -577,6 +583,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -591,6 +600,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -591,6 +600,9 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1, 1,
1, 1,
a_element_op, a_element_op,
......
...@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
int M, int M,
int N, int N,
int K, int K,
int BatchStrideA,
int BatchStrideB,
int BatchStrideC,
int StrideA, int StrideA,
int StrideB, int StrideB,
int StrideC, int StrideC,
...@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
std::size_t row, std::size_t row,
std::size_t col, std::size_t col,
std::size_t stride, std::size_t stride,
std::size_t batch_stride,
auto layout) { auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value) if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({row * stride, stride, 1})); std::vector<std::size_t>({batch_stride, stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({col * stride, 1, stride})); std::vector<std::size_t>({batch_stride, 1, stride}));
} }
}; };
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); Tensor<ADataType> a_g_m_k(
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
Tensor<CDataType> c_g_m_n_host_result( Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_n_device_result( Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
...@@ -150,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -150,6 +156,9 @@ bool profile_batched_gemm_impl(int do_verification,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
......
...@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K; const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M; const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;
const int BatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
const int BatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
const int BatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
bool pass = ck::profiler:: bool pass = ck::profiler::
profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>( profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
do_verification, do_verification,
...@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
M, M,
N, N,
K, K,
(StrideA < 0) ? DefaultStrideA : StrideA, BatchStrideA,
(StrideB < 0) ? DefaultStrideB : StrideB, BatchStrideB,
(StrideC < 0) ? DefaultStrideC : StrideC, BatchStrideC,
StrideA_,
StrideB_,
StrideC_,
BatchCount); BatchCount);
return pass ? 0 : 1; return pass ? 0 : 1;
......
...@@ -25,19 +25,19 @@ int main() ...@@ -25,19 +25,19 @@ int main()
pass = pass && pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>( ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
true, 1, false, 1, M, N, K, K, N, N, BatchCount); true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>( ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
true, 1, false, 1, M, N, K, K, K, N, BatchCount); true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>( ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
true, 1, false, 1, M, N, K, M, N, N, BatchCount); true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>( ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
true, 1, false, 1, M, N, K, M, K, N, BatchCount); true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl; std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1; return pass ? 0 : 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment