Commit 6adf3591 authored by Jing Zhang's avatar Jing Zhang
Browse files

add batch_stride

parent fa9a0a5c
...@@ -26,6 +26,9 @@ struct DeviceBatchedGemm : public BaseOperator ...@@ -26,6 +26,9 @@ struct DeviceBatchedGemm : public BaseOperator
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
......
...@@ -334,6 +334,9 @@ struct DeviceBatchedGemmXdl ...@@ -334,6 +334,9 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t M01, index_t M01,
index_t N01, index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -350,10 +353,7 @@ struct DeviceBatchedGemmXdl ...@@ -350,10 +353,7 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{ compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
block_2_ctile_map_{ block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01}, M01_{M01},
...@@ -536,6 +536,9 @@ struct DeviceBatchedGemmXdl ...@@ -536,6 +536,9 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -550,6 +553,9 @@ struct DeviceBatchedGemmXdl ...@@ -550,6 +553,9 @@ struct DeviceBatchedGemmXdl
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1, 1,
1, 1,
a_element_op, a_element_op,
...@@ -570,6 +576,9 @@ struct DeviceBatchedGemmXdl ...@@ -570,6 +576,9 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -584,6 +593,9 @@ struct DeviceBatchedGemmXdl ...@@ -584,6 +593,9 @@ struct DeviceBatchedGemmXdl
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1, 1,
1, 1,
a_element_op, a_element_op,
......
...@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
int M, int M,
int N, int N,
int K, int K,
int BatchStrideA,
int BatchStrideB,
int BatchStrideC,
int StrideA, int StrideA,
int StrideB, int StrideB,
int StrideC, int StrideC,
...@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
std::size_t row, std::size_t row,
std::size_t col, std::size_t col,
std::size_t stride, std::size_t stride,
std::size_t batch_stride,
auto layout) { auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value) if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({row * stride, stride, 1})); std::vector<std::size_t>({batch_stride, stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}), return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({col * stride, 1, stride})); std::vector<std::size_t>({batch_stride, 1, stride}));
} }
}; };
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); Tensor<ADataType> a_g_m_k(
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
Tensor<CDataType> c_g_m_n_host_result( Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_n_device_result( Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
...@@ -148,6 +154,9 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -148,6 +154,9 @@ bool profile_batched_gemm_impl(int do_verification,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
......
...@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K; const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M; const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;
const int BatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
const int BatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
const int BatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
bool pass = ck::profiler:: bool pass = ck::profiler::
profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>( profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
do_verification, do_verification,
...@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[]) ...@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
M, M,
N, N,
K, K,
(StrideA < 0) ? DefaultStrideA : StrideA, BatchStrideA,
(StrideB < 0) ? DefaultStrideB : StrideB, BatchStrideB,
(StrideC < 0) ? DefaultStrideC : StrideC, BatchStrideC,
StrideA_,
StrideB_,
StrideC_,
BatchCount); BatchCount);
return pass ? 0 : 1; return pass ? 0 : 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment