Commit 66d08ea3 authored by coderfeli

impl topk weight scatter

parent a8a82e0c
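
The commit threads the per-row top-k gate weight into the scatter epilogue of the MoE GEMM, so each expert's output row is scaled by its gate weight before being accumulated into the owning token's row of E. A minimal host-side sketch of that behaviour, with names and layouts chosen for illustration rather than taken from the CK API:

```cpp
#include <cstdint>
#include <vector>

// Hedged reference of a top-k weight scatter: each sorted row of the expert GEMM
// output is scaled by its gate weight and added into the destination token's row.
// The low 24 bits of the packed sorted token id give the token index, matching
// the `& 0xffffff` masking used in the kernel changes below.
void scatter_topk_rows(const std::vector<float>& gemm_out,            // [sorted_size, N]
                       const std::vector<uint32_t>& sorted_token_ids, // packed ids, one per row
                       const std::vector<float>& topk_weights,        // gate weight per row
                       std::vector<float>& e_out,                     // [tokens, N], zero-initialized
                       int sorted_size,
                       int N)
{
    for(int m = 0; m < sorted_size; ++m)
    {
        const int   token = sorted_token_ids[m] & 0xffffff;
        const float w     = topk_weights[m];
        for(int n = 0; n < N; ++n)
            e_out[token * N + n] += w * gemm_out[m * N + n]; // the device kernel does this atomically
    }
}
```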
@@ -58,7 +58,7 @@ struct MulABScaleExpertWeight
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
-//gpu
+//real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>
(EDataType& e,
@@ -211,8 +211,9 @@ int main(int argc, char* argv[])
ck::index_t StrideA = K;
ck::index_t StrideB = K;
-// ck::index_t StrideD = 0;
ck::index_t StrideE = N;
+constexpr ck::index_t NumDTensor = DsDataType::Size();
+constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
@@ -238,9 +239,9 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-Tensor<D0DataType> d0_t_n(HostTensorDescriptor({SORTED_SIZE, N}, {0, 0}));
-Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {0, 0}));
-Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, 1}, {1, 0}));
+Tensor<D0DataType> d0_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {StrideDs[0], 0}));
+Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
+Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
@@ -248,7 +249,7 @@ int main(int argc, char* argv[])
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
-std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
+std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
@@ -257,21 +258,21 @@ int main(int argc, char* argv[])
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
-d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
+d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
+d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
-d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
+d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
@@ -279,7 +280,7 @@ int main(int argc, char* argv[])
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
-DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
+DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
@@ -287,7 +288,7 @@ int main(int argc, char* argv[])
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data());
-d0_device_buf.ToDevice(d0_t_n.mData.data());
+d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
@@ -296,9 +297,6 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
-constexpr ck::index_t NumDTensor = DsDataType::Size();
-constexpr auto I0 = ck::Number<0>{};
// do GEMM
auto device_op = DeviceOpInstance{};
@@ -325,7 +323,7 @@ int main(int argc, char* argv[])
K,
StrideA,
StrideB,
-std::array<ck::index_t, NumDTensor>{I0, I0, I0},
+StrideDs,
StrideE,
KBatch,
a_element_op,
@@ -375,7 +373,7 @@ int main(int argc, char* argv[])
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(
-sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
+sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_m_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
ref_invoker.Run(ref_argument);
...
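
One detail worth noting in the example's setup above: the D tensors are declared with partially zero strides, and a zero stride makes every index along that dimension read the same element, which is how a per-row or per-expert scale is broadcast across N without materializing a full matrix. A small standalone illustration of stride-0 broadcasting (plain C++, not the CK HostTensorDescriptor machinery):

```cpp
#include <cstdio>
#include <vector>

// Stride-0 broadcasting: a length-M vector addressed with strides {1, 0} behaves
// like an M x N tensor whose rows repeat the same value across the N dimension.
int main()
{
    const int M = 4, N = 3;
    const std::vector<float> scale = {0.5f, 1.0f, 2.0f, 4.0f}; // logical shape {M, N}
    const int stride_m = 1, stride_n = 0;                      // strides {1, 0}

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            std::printf("d[%d][%d] = %.1f\n", m, n, scale[m * stride_m + n * stride_n]);
    return 0;
}
```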
@@ -64,13 +64,15 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op,
-const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets)
+const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
+const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op,
-scatter_offsets)
+scatter_offsets,
+scatter_weights)
{
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
...
@@ -305,23 +305,43 @@ struct DeviceMoeGemm
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
-// const auto kernel = kernel_moe_gemm_gather<
-// GridwiseGemm,
-// true,
-// InMemoryDataOperationEnum::AtomicAdd,
-// minimum_occupancy,
-// TailNumber::Odd>;
-// RunKernel(kernel);
+// if constexpr (IsGatherGemm) {
+// const auto kernel = kernel_moe_gemm_gather<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::AtomicAdd,
+// minimum_occupancy,
+// TailNumber::Odd>;
+// RunKernel(kernel);
+// else {
+// const auto kernel = kernel_moe_gemm_scatter<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::AtomicAdd,
+// minimum_occupancy,
+// TailNumber::Odd>;
+// RunKernel(kernel);
+// }
// }
// else
// {
-// const auto kernel = kernel_moe_gemm_gather<
-// GridwiseGemm,
-// true,
-// InMemoryDataOperationEnum::AtomicAdd,
-// minimum_occupancy,
-// TailNumber::Even>;
-// RunKernel(kernel);
+// if constexpr (IsGatherGemm) {
+// const auto kernel = kernel_moe_gemm_gather<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::AtomicAdd,
+// minimum_occupancy,
+// TailNumber::Even>;
+// RunKernel(kernel);
+// else {
+// const auto kernel = kernel_moe_gemm_scatter<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::AtomicAdd,
+// minimum_occupancy,
+// TailNumber::Even>;
+// RunKernel(kernel);
+// }
// }
// }
// else
...
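
The block above stays commented out in this commit, but it sketches how the launch path would pick between the gather and scatter kernel variants. A self-contained sketch of that dispatch pattern follows; the real code would launch GPU kernels through RunKernel, and all names here are stand-ins rather than the actual CK symbols:

```cpp
#include <cstdio>

enum class TailNumber { Odd, Even };

// Stand-ins for the gather/scatter kernel templates named in the commented block.
template <TailNumber Tail>
void fake_gather_kernel() { std::printf("gather kernel, tail=%d\n", static_cast<int>(Tail)); }

template <TailNumber Tail>
void fake_scatter_kernel() { std::printf("scatter kernel, tail=%d\n", static_cast<int>(Tail)); }

// A compile-time flag selects the variant; the runtime tail number selects the instantiation.
template <bool IsGatherGemm>
void dispatch(TailNumber tail)
{
    auto run = [](auto kernel) { kernel(); }; // stand-in for RunKernel
    if(tail == TailNumber::Odd)
    {
        if constexpr(IsGatherGemm) { run(fake_gather_kernel<TailNumber::Odd>); }
        else                       { run(fake_scatter_kernel<TailNumber::Odd>); }
    }
    else
    {
        if constexpr(IsGatherGemm) { run(fake_gather_kernel<TailNumber::Even>); }
        else                       { run(fake_scatter_kernel<TailNumber::Even>); }
    }
}

int main()
{
    dispatch<false>(TailNumber::Odd); // the scatter path this commit is concerned with
    return 0;
}
```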
@@ -486,13 +486,36 @@ struct GridwiseMoeGemmScatter
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
+template <typename DLayout>
+__host__ __device__ static auto
+MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
+{
+const auto c_grid_desc_mraw_nraw = [&]() {
+if constexpr(is_same<tensor_layout::gemm::RowMajor, DLayout>::value)
+{
+return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
+}
+else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLayout>::value)
+{
+return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
+}
+}();
+// pad M and N
+return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
+make_tuple(make_right_pad_transform(M, MPad - M),
+make_right_pad_transform(N, NPad - N)),
+make_tuple(Sequence<0>{}, Sequence<1>{}),
+make_tuple(Sequence<0>{}, Sequence<1>{}));
+}
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
+return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
},
Number<NumDTensor>{});
}
@@ -509,7 +532,6 @@ struct GridwiseMoeGemmScatter
Number<NumDTensor>{});
}
-using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
struct Problem
{
@@ -1354,6 +1376,14 @@ struct GridwiseMoeGemmScatter
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
+using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+const DDataType *ptr_ = p_ds_grid[i];
+// hack logic here to support different kind of strides. todo fix it.
+// ascale M, 1; bscale E, N, 1, move ptr to E
+if (i.value == 1)
+{
+ptr_ += expert_id * problem.StrideDs[1] * problem.N;
+}
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
@@ -1398,8 +1428,12 @@ struct GridwiseMoeGemmScatter
// static_assert(EMRepeats == 1, "only support 1 line per thread now!");
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[token_pos];
+StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
+// too hack here, 2 specific for topk weights, fixme
+const float *p_sorted_weights = p_ds_grid[I2];
static_for<0, EMRepeats, 1>{}([&](auto m0) {
scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N;
+scatter_weights(m0) = p_sorted_weights[token_pos + m0];
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
@@ -1435,7 +1469,8 @@ struct GridwiseMoeGemmScatter
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op,
-scatter_offsets};
+scatter_offsets,
+scatter_weights};
// if(threadIdx.x== 0)
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
...
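
At the thread level, the scatter epilogue above derives two things per output row: a destination offset from the low 24 bits of the packed sorted token id (times the row length N), and the top-k weight read from the third D tensor. A standalone sketch of that addressing; the 24-bit packing of the token index is taken from the mask in the kernel, while the meaning of the upper bits is an assumption and is left out:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t sorted_token_ids[] = {0x01000005u, 0x00000007u}; // hypothetical packed ids
    const float    sorted_weights[]   = {0.75f, 0.25f};             // gate weight per sorted row
    const int      N = 128;                                         // output row length

    for(int m = 0; m < 2; ++m)
    {
        const int   token          = sorted_token_ids[m] & 0xffffff; // low 24 bits -> token index
        const int   scatter_offset = token * N;                      // element offset of the E row
        const float scatter_weight = sorted_weights[m];              // folded into the scattered store
        std::printf("row %d -> token %d, offset %d, weight %.2f\n",
                    m, token, scatter_offset, scatter_weight);
    }
    return 0;
}
```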
@@ -99,11 +99,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
const ElementwiseOperation& element_op,
-const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets)
+const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
+const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
element_op_(element_op),
-scatter_offsets_(scatter_offsets)
+scatter_offsets_(scatter_offsets),
+scatter_weights_(scatter_weights)
{
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! cannot evenly divide");
@@ -172,14 +174,20 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
src_coords_[i]);
oob_val = oob_val & is_src_valid;
-if constexpr(SrcScalarPerVectors{}[i] == 1)
+if (i.value == 2)
+{
+static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec");
+constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
+static_for<0, SrcScalarPerVector, 1>{}(
+[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{}); });
+}
+else if constexpr(SrcScalarPerVectors{}[i] == 1)
{
auto data_types = SrcDatas{};
using DataType = remove_cvref_t<decltype(data_types[i])>;
const auto tmp =
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
+// printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
}
@@ -691,6 +699,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
DstCoords dst_coords_;
const ElementwiseOperation element_op_;
StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
+StaticallyIndexedArray<float, scatter_num> scatter_weights_;
};
} // namespace ck