Commit 6adf3591 authored by Jing Zhang's avatar Jing Zhang
Browse files

add batch_stride

parent fa9a0a5c
......@@ -26,6 +26,9 @@ struct DeviceBatchedGemm : public BaseOperator
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -334,6 +334,9 @@ struct DeviceBatchedGemmXdl
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
......@@ -350,10 +353,7 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01},
......@@ -536,6 +536,9 @@ struct DeviceBatchedGemmXdl
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......@@ -550,6 +553,9 @@ struct DeviceBatchedGemmXdl
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1,
1,
a_element_op,
......@@ -570,6 +576,9 @@ struct DeviceBatchedGemmXdl
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......@@ -584,6 +593,9 @@ struct DeviceBatchedGemmXdl
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
1,
1,
a_element_op,
......
......@@ -34,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
int M,
int N,
int K,
int BatchStrideA,
int BatchStrideB,
int BatchStrideC,
int StrideA,
int StrideB,
int StrideC,
......@@ -45,25 +48,28 @@ bool profile_batched_gemm_impl(int do_verification,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({row * stride, stride, 1}));
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({col * stride, 1, stride}));
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
......@@ -148,6 +154,9 @@ bool profile_batched_gemm_impl(int do_verification,
StrideA,
StrideB,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
......
......@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA;
const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB;
const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC;
const int BatchStrideA = (ck::is_same_v<ALayout, Row> ? M : K) * StrideA_;
const int BatchStrideB = (ck::is_same_v<BLayout, Row> ? K : N) * StrideB_;
const int BatchStrideC = (ck::is_same_v<CLayout, Row> ? M : N) * StrideC_;
bool pass = ck::profiler::
profile_batched_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
do_verification,
......@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
StrideA_,
StrideB_,
StrideC_,
BatchCount);
return pass ? 0 : 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment