Commit ad11d2a4 authored by Chao Liu
Browse files

fix

parent 2488d0bf
...@@ -503,8 +503,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -503,8 +503,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
static_assert(NumDTensor == 0, "wrong!");
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<decltype(DsDataType{}.At(i))>; using DDataType = tuple_element_t<i.value, DsDataType>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
......
...@@ -549,6 +549,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -549,6 +549,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{})); Number<NumDTensor>{}));
// tuple of references to C/Ds buffers (C shuffle block buffer followed by the Ds grid buffers)
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy // tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat( const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)), make_tuple(make_multi_index(0, 0, 0, 0)),
...@@ -561,9 +569,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -561,9 +569,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// blockwise copy C/D/E between LDS and global // blockwise copy C/D/E between LDS and global
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock, ThisThreadBlock,
Tuple<FloatCShuffle, decltype(container_concat(make_tuple(FloatCShuffle{}), DsDataType{})),
remove_cvref_t<tuple_element_t<0, DsDataType>>,
remove_cvref_t<tuple_element_t<1, DsDataType>>>,
Tuple<FloatE>, Tuple<FloatE>,
decltype(c_ds_desc_refs), decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
...@@ -633,7 +639,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -633,7 +639,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// each block copy its data from LDS to global // each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run( cde_block_copy_lds_and_global.Run(
c_ds_desc_refs, c_ds_desc_refs,
tie(c_shuffle_block_buf, ds_grid_buf[I0], ds_grid_buf[I1]), c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf)); tie(e_grid_buf));
......
...@@ -240,7 +240,13 @@ struct arithmetic_sequence_gen ...@@ -240,7 +240,13 @@ struct arithmetic_sequence_gen
} }
}; };
using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
using type1 = Sequence<>;
static constexpr bool kHasContent =
(Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);
using type = typename conditional<kHasContent, type0, type1>::type;
}; };
// uniform sequence // uniform sequence
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment