Unverified Commit 35e49f2d authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

add g; fixed strides (#355)

parent de60d290
add_example_executable(example_gemm_bias_e_permute_m3n2_xdl_fp16 gemm_bias_e_permute_m3n2_xdl_fp16.cpp) add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_m2n3_xdl_fp16 gemm_bias_e_permute_m2n3_xdl_fp16.cpp) add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
...@@ -500,11 +500,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -500,11 +500,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std::array<long_index_t, NumDTensor> ds_offset; std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
if constexpr(NumDimG > 0) ds_offset[i] =
ds_offset[i] = ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
else
ds_offset[i] = 0;
}); });
return ds_offset; return ds_offset;
...@@ -512,10 +509,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -512,10 +509,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{ {
if constexpr(NumDimG > 0) return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
else
return 0;
} }
private: private:
...@@ -634,6 +628,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -634,6 +628,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
compute_ptr_offset_of_batch_{ compute_ptr_offset_of_batch_{
a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_} a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, "");
// populate pointer, batch stride, desc for Ds // populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment