Commit 987eff1b authored by Jing Zhang
Browse files

fixed GetDsPtrOffset/GetEPtrOffset with long_index

parent 5d7ab929
...@@ -228,7 +228,7 @@ int main(int argc, char* argv[]) ...@@ -228,7 +228,7 @@ int main(int argc, char* argv[])
// E[G, N0, M0, N1, M1, N2] // E[G, N0, M0, N1, M1, N2]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G, M0, M1, N0, N1, N2}; std::vector<ck::index_t> e_gs_ms_ns_lengths{G, M0, M1, N0, N1, N2};
std::vector<ck::index_t> e_gs_ms_ns_strides{ std::vector<ck::index_t> e_gs_ms_ns_strides{
M0 * M1 * N0 * N1 * N2, N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1}; N0 * M0 * N1 * M1 * N2, N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1};
if(argc == 1) if(argc == 1)
{ {
...@@ -257,9 +257,6 @@ int main(int argc, char* argv[]) ...@@ -257,9 +257,6 @@ int main(int argc, char* argv[])
Tensor<DDataType> d_gs_ms_ns( Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result( Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
...@@ -267,7 +264,7 @@ int main(int argc, char* argv[]) ...@@ -267,7 +264,7 @@ int main(int argc, char* argv[])
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_device_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -359,9 +356,26 @@ int main(int argc, char* argv[]) ...@@ -359,9 +356,26 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
Tensor<CShuffleDataType> c_gs_ms_ns_host_result( const ck::index_t G_ = 1;
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), const ck::index_t N0_ = 3;
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
// A[G, M0, M1, K0]
std::vector<ck::index_t> host_a_gs_ms_ks_lengths{G_, M0, M1, K0};
std::vector<ck::index_t> host_a_gs_ms_ks_strides{M0 * M1 * K0, M1 * K0, K0, 1};
// B[G, N0_, N1, N2, K0]
std::vector<ck::index_t> host_b_gs_ns_ks_lengths{G_, N0_, N1, N2, K0};
std::vector<ck::index_t> host_b_gs_ns_ks_strides{
N0_ * N1 * N2 * K0, N1 * N2 * K0, N2 * K0, K0, 1};
// D[G_, N0_, M0, N1, M1, N2]
std::vector<ck::index_t> host_d_gs_ms_ns_lengths{G_, M0, M1, N0_, N1, N2};
std::vector<ck::index_t> host_d_gs_ms_ns_strides{N0_ * N1 * N2, 0, 0, N1 * N2, N2, 1};
// E[G_, N0_, M0, N1, M1, N2]
std::vector<ck::index_t> host_e_gs_ms_ns_lengths{G_, M0, M1, N0_, N1, N2};
std::vector<ck::index_t> host_e_gs_ms_ns_strides{
N0_ * M0 * N1 * M1 * N2, N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1};
using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM, using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
NumDimN, NumDimN,
...@@ -377,8 +391,44 @@ int main(int argc, char* argv[]) ...@@ -377,8 +391,44 @@ int main(int argc, char* argv[])
auto ref_gemm = ReferenceOpInstance{}; auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_gs_ms_ks, Tensor<ADataType> host_a_gs_ms_ks(std::vector<std::size_t>(host_a_gs_ms_ks_lengths.begin(),
b_gs_ns_ks, host_a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(host_a_gs_ms_ks_strides.begin(),
host_a_gs_ms_ks_strides.end()));
Tensor<BDataType> host_b_gs_ns_ks(std::vector<std::size_t>(host_b_gs_ns_ks_lengths.begin(),
host_b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(host_b_gs_ns_ks_strides.begin(),
host_b_gs_ns_ks_strides.end()));
Tensor<DDataType> host_d_gs_ms_ns(std::vector<std::size_t>(host_d_gs_ms_ns_lengths.begin(),
host_d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(host_d_gs_ms_ns_strides.begin(),
host_d_gs_ms_ns_strides.end()));
std::copy(a_gs_ms_ks.mData.begin(), a_gs_ms_ks.mData.end(), host_a_gs_ms_ks.begin());
std::copy(b_gs_ns_ks.mData.begin(), b_gs_ns_ks.mData.end(), host_b_gs_ns_ks.begin());
std::copy(d_gs_ms_ns.mData.begin(), d_gs_ms_ns.mData.end(), host_d_gs_ms_ns.begin());
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(host_e_gs_ms_ns_lengths.begin(),
host_e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(host_e_gs_ms_ns_strides.begin(),
host_e_gs_ms_ns_strides.end()));
std::cout << "host_a_gs_ms_ks: " << host_a_gs_ms_ks.mDesc << std::endl;
std::cout << "host_b_gs_ns_ks: " << host_b_gs_ns_ks.mDesc << std::endl;
std::cout << "host_d_gs_ms_ns: " << host_d_gs_ms_ns.mDesc << std::endl;
std::cout << "host_e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
std::vector<std::size_t>(host_e_gs_ms_ns_lengths.begin(),
host_e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(host_e_gs_ms_ns_strides.begin(),
host_e_gs_ms_ns_strides.end()));
auto ref_argument = ref_gemm.MakeArgument(host_a_gs_ms_ks,
host_b_gs_ns_ks,
c_gs_ms_ns_host_result, c_gs_ms_ns_host_result,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -401,7 +451,7 @@ int main(int argc, char* argv[]) ...@@ -401,7 +451,7 @@ int main(int argc, char* argv[])
{ {
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2), cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2), c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
d_gs_ms_ns(g0, m0, m1, n0, n1, n2)); host_d_gs_ms_ns(g0, m0, m1, n0, n1, n2));
} }
} }
} }
......
...@@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(batch_stride_A_); return static_cast<long_index_t>(g_idx) * batch_stride_A_;
} }
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(batch_stride_B_); return static_cast<long_index_t>(g_idx) * batch_stride_B_;
} }
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
...@@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std::array<long_index_t, NumDTensor> ds_offset; std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] = ds_offset[i] = static_cast<long_index_t>(g_idx) *
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0)); ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
}); });
return ds_offset; return ds_offset;
...@@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{ {
return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return static_cast<long_index_t>(g_idx) *
e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
} }
private: private:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment