"examples/community/latent_consistency_interpolate.py" did not exist on "27062c3631b7011a5df45782b8e3d01349d1f3e9"
Commit 1c5b6f7d authored by letaoqin
Browse files

gridwise add multiple d

parent c7bf4232
......@@ -11,6 +11,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -543,6 +544,80 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
// Build a tuple with one dynamic buffer per D tensor (NumDTensor entries).
// Entry i wraps the pointer p_ds_grid[i], is tagged AddressSpaceEnum::Global,
// and is sized by the element-space size reported by the i-th D grid descriptor.
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
// Build a tuple with one per-thread staging buffer per D tensor.
// Entry i is a StaticBuffer tagged AddressSpaceEnum::Vgpr whose element type is
// the i-th type in DsDataType (cv/ref stripped) and whose length equals the
// element-space size of the C thread descriptor, so it has one slot for every
// slot of the thread's C tile.
// NOTE(review): the meaning of the trailing `true` template argument is not
// visible in this chunk — confirm against the StaticBuffer declaration.
auto ds_thread_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DDataType,
c_thread_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize(),
true>{};
},
Number<NumDTensor>{});
// Build a tuple with one ThreadwiseTensorSliceTransfer_v2 object per D tensor.
// For entry i:
//   - source and destination element types are both the i-th D data type;
//   - the source descriptor type is that of the i-th D grid descriptor and the
//     destination descriptor type is that of the C thread descriptor;
//   - the slice lengths are (1, len[0], len[1], 1, len[2], len[3]) where `len`
//     is c_m10_m11_n10_n11_thread_tensor_lengths;
//   - access order, vector dimension and scalars-per-vector reuse the
//     CThreadTransfer* parameters that are in scope here.
// Each transfer is constructed with the i-th D grid descriptor and a start index
// (im0, origin[0], origin[1], in0, origin[2], origin[3]) where `origin` is
// c_m10_m11_n10_n11_thread_origin_idx_on_block.
// NOTE(review): the meaning of the last two template arguments (`1`, `false`)
// is not visible in this chunk — confirm against the
// ThreadwiseTensorSliceTransfer_v2 declaration.
auto ds_threadwise_copy = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
DDataType,
DDataType,
decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
Sequence<I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
1,
false>(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(
im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])); // register number
},
Number<NumDTensor>{});
// Run the transfer for every D tensor: the source is ds_grid_buf[i] viewed
// through the i-th D grid descriptor, the destination is ds_thread_buf(i) viewed
// through the C thread descriptor, with the destination index given as I0 in
// all six dimensions.
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
ds_grid_buf[i],
c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
ds_thread_buf(i));
});
// Apply cde_element_op once per element of the thread's C tile. The index i
// runs over the element-space size of the C thread descriptor, and the same i
// indexes both c_thread_buf and every D thread buffer.
static_for<0, c_thread_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize(), 1>{}(
[&](auto i) {
// Source operands: one const reference per D tensor, each referring to
// element i of that tensor's thread buffer.
const auto src_data_refs = generate_tie(
// The lambda must return an lvalue reference so the tie refers to the
// buffer element itself rather than a temporary copy.
[&](auto iSrc) -> const auto& { return ds_thread_buf[iSrc][i]; },
Number<NumDTensor>{});
// Destination operands: TWO mutable references, both to c_thread_buf(i);
// the generator ignores its index, so the same element appears twice.
// NOTE(review): presumably cde_element_op takes an output followed by the
// accumulator input and the two are intentionally aliased so the result
// overwrites c_thread_buf(i) in place — confirm against the element op.
auto dst_data_refs = generate_tie(
// The lambda must return an lvalue reference (see above).
[&](auto) -> auto& { return c_thread_buf(i); },
Number<2>{});
// Invoke cde_element_op with the destination references followed by the
// source references, both tuples expanded into individual arguments.
unpack2(cde_element_op, dst_data_refs, src_data_refs);
});
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment