Commit f9abcf80 authored by coderfeli's avatar coderfeli
Browse files

Use per-thread gather offsets in the threadwise tensor-slice transfer (v3r1 gather variant); verified working.

parent e947d11e
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp" #include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp"
namespace ck { namespace ck {
...@@ -41,14 +41,15 @@ template <typename ThreadGroup, ...@@ -41,14 +41,15 @@ template <typename ThreadGroup,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun,
index_t GatherDim = 1,
index_t NumThreadScratch = 1> index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_mod8 struct ThreadGroupTensorSliceTransfer_v4r1_mod8
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
// using GatherIndex = MultiIndex<gather_num>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8( __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8(
const SrcDesc& src_desc, const SrcDesc& src_desc,
...@@ -56,13 +57,15 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8 ...@@ -56,13 +57,15 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
const SrcElementwiseOperation& src_element_op, const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc, const DstDesc& dst_desc,
const Index& dst_block_slice_origin, const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op) const DstElementwiseOperation& dst_element_op,
const StaticallyIndexedArray<index_t, gather_num> &gather_offsets)
: threadwise_transfer_(src_desc, : threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(), make_zero_multi_index<nDim>(),
src_element_op, src_element_op,
dst_desc, dst_desc,
make_zero_multi_index<nDim>(), make_zero_multi_index<nDim>(),
dst_element_op) dst_element_op,
gather_offsets)
{ {
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() && static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
...@@ -173,7 +176,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8 ...@@ -173,7 +176,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths), ThreadwiseTensorSliceTransfer_v3r1_gather<decltype(thread_slice_lengths),
SrcElementwiseOperation, SrcElementwiseOperation,
DstElementwiseOperation, DstElementwiseOperation,
DstInMemOp, DstInMemOp,
...@@ -191,6 +194,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8 ...@@ -191,6 +194,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8
DstScalarStrideInVector, DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun,
GatherDim,
NumThreadScratch>; NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_; ThreadwiseTransfer threadwise_transfer_;
......
...@@ -1132,8 +1132,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1132,8 +1132,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr auto MLoadRepeats = MPerBlock / MLoadThreads; constexpr auto MLoadRepeats = MPerBlock / MLoadThreads;
static_assert(MLoadRepeats == 1, "only support 1 line per thread now!"); static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads; const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads;
StaticallyIndexedArray<index_t, MLoadRepeats> token_offsets; //= p_sorted_token_ids[token_pos];
index_t token_offset = p_sorted_token_ids[token_pos]; static_for<0, MLoadRepeats, 1>{}([&](auto m0) {
token_offsets(m0) = p_sorted_token_ids[token_pos + MLoadThreads * m0] * problem.K;
});
printf("threadIdx.x %d off %d\n", threadIdx.x, token_offsets(I0));
const index_t m_block_data_idx_on_grid = const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
...@@ -1183,13 +1186,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1183,13 +1186,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true, true,
1,
BlockwiseGemmPipe::GlobalBufferNum>( BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
make_multi_index(0, token_offset, 0), make_multi_index(0, 0, 0),
a_element_op, a_element_op,
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{},
token_offsets);
// Thread-wise copy // Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment