Commit b02c0b82 authored by coderfeli

gemm1 scale debug

parent e4ca61f9
@@ -35,7 +35,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using A0DataType = F8;
 using B0DataType = F8;
-using EDataType = F16;
+using EDataType = F32;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using D0DataType = F32;
@@ -68,9 +68,11 @@ struct MulABScale
                    const float& d1,
                    const D2DataType& d2) const
     {
-        // const float x0_f = c * d0 * d1;
-        (void)d0; (void)d1; (void)d2;
-        const float x0_f = c;
+        (void)d2; // for gate, no d2 needed
+        (void)d0;
+        (void)d1;
+        const float x0_f = c;
+        // const float x0_f = c;
         e = ck::type_convert<EDataType>(x0_f);
     }
 };
@@ -91,8 +93,10 @@ struct MulABScaleSiluMulGate
                    const D2DataType& d2) const
     {
         // act
+        (void)d0;
+        (void)d1;
         (void)d2;
-        float x0;
+        float x0 = 0;
         ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
         // fuse mul
         e = ck::type_convert<EDataType>(x0);
@@ -145,7 +149,7 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
-static constexpr ck::index_t MPerBlock = 128;
+static constexpr ck::index_t MPerBlock = 32;
 static constexpr ck::index_t MNPerXDL = 32;
 static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
 static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; // todo fix this constraint
@@ -208,7 +212,7 @@ int main(int argc, char* argv[])
     ck::index_t sorted_tile_num = 8;
     ck::index_t sorted_tile_size = MPerBlock;
     ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
-    ck::index_t tokens = 512;
+    ck::index_t tokens = 32;
     if(argc == 1)
     {
@@ -236,6 +240,8 @@ int main(int argc, char* argv[])
     ck::index_t StrideB = K;
     // ck::index_t StrideD = 0;
     ck::index_t StrideE = N;
+    constexpr ck::index_t NumDTensor = DsDataType::Size();
+    constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
     ck::index_t KBatch = 1;
@@ -261,15 +267,15 @@ int main(int argc, char* argv[])
     Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
     Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
     Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-    Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
-    Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
+    Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
+    Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
     Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
     Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
     Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
     std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
     std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
-    std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
+    std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
     std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
     std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
@@ -281,21 +287,21 @@ int main(int argc, char* argv[])
         a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
         d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
-        d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
+        d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
         d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
         break;
     case 2:
         a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
         d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
-        d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
+        d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
         d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
         break;
     default:
         a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
         d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
-        d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
+        d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
         d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
     }
     DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
@@ -303,7 +309,7 @@ int main(int argc, char* argv[])
     DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
     DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
     DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
-    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
     DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
     a0_t_k.savetxt("a.txt");
@@ -311,7 +317,7 @@ int main(int argc, char* argv[])
     expert_ids_dev.ToDevice(expert_ids.mData.data());
     a0_device_buf.ToDevice(a0_t_k.mData.data());
     d0_device_buf.ToDevice(d0_t_n.mData.data());
-    d1_device_buf.ToDevice(d1_m_n.mData.data());
+    d1_device_buf.ToDevice(d1_e_n.mData.data());
     d2_device_buf.ToDevice(d2_m_n.mData.data());
     e_device_buf.ToDevice(e_m_n_device_result.mData.data());
@@ -319,8 +325,6 @@ int main(int argc, char* argv[])
     auto b_element_op = BElementOp{};
     auto cde_element_op = CDEElementOp{};
-    constexpr ck::index_t NumDTensor = DsDataType::Size();
-
     constexpr auto I0 = ck::Number<0>{};
     // do GEMM
@@ -404,7 +408,7 @@ int main(int argc, char* argv[])
         const int t = sorted_token_ids(m);
         for(int n = 0; n < N; ++n)
         {
-            cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_m_n(m, n), d2_m_n(m, n));
+            cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(m, n), d2_m_n(m, n));
         }
     }
...
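
Note: for reference, a minimal host-side sketch (plain C++, not the CK functors themselves) of what the two epilogues above compute in this commit's debug state; mul_ab_scale and mul_ab_scale_silu_gate are hypothetical stand-ins. MulABScale currently drops the d0/d1/d2 scales and forwards the accumulator, while MulABScaleSiluMulGate still applies the per-token (d0) and per-expert (d1) scales before the SiLU.

#include <cmath>

// Debug state of MulABScale in this commit: scales are cast to void, e = c.
static float mul_ab_scale(float c, float /*d0*/, float /*d1*/, float /*d2*/)
{
    return c;
}

// MulABScaleSiluMulGate: scale by d0 (per token) and d1 (per expert), then
// SiLU: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
static float mul_ab_scale_silu_gate(float c, float d0, float d1)
{
    const float x = c * d1 * d0;
    return x / (1.0f + std::exp(-x));
}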
@@ -43,13 +43,15 @@ template <typename ThreadGroup,
           typename ThreadTransferSrcResetCoordinateAfterRunFlags,
           typename ThreadTransferDstResetCoordinateAfterRunFlags,
           index_t ScatterDim = 1,
+          bool OutputScatter = true,
+          index_t ScatterWeightIdx = 3,
           index_t NumThreadScratch = 1>
 struct ThreadGroupTensorSliceTransfer_v7r3_scatter
 {
     static constexpr index_t nDim =
         remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
     static constexpr index_t mod_num = ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix
     static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
     static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
@@ -114,7 +116,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
             Number<nSrc>{});
         const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-            make_multi_index(ThreadGroup::GetThreadId() % mod_num));
+            make_multi_index(OutputScatter ? ThreadGroup::GetThreadId() % mod_num : ThreadGroup::GetThreadId()));
         const auto dst_thread_slice_origins = generate_tuple(
             [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
             Number<nDst>{});
@@ -219,6 +221,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
         ThreadTransferSrcResetCoordinateAfterRunFlags,
         ThreadTransferDstResetCoordinateAfterRunFlags,
         ScatterDim,
+        OutputScatter,
+        ScatterWeightIdx,
         NumThreadScratch>;
     ThreadwiseTransfer threadwise_transfer_;
...
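
Note: a compilable sketch of the OutputScatter switch added to the dst thread-cluster index above; dst_cluster_thread_id is a hypothetical helper (in the diff the expression sits inline in make_multi_index). When the transfer scatters its output, thread ids are folded modulo mod_num, the last ThreadClusterLengths entry, so threads covering the same scatter row share a cluster slot; a plain store uses the full thread id.

using index_t = int; // stand-in for ck::index_t

index_t dst_cluster_thread_id(index_t tid, index_t mod_num, bool output_scatter)
{
    // OutputScatter ? tid % mod_num : tid, as in the diff above
    return output_scatter ? tid % mod_num : tid;
}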
@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
 #define DEBUG_LOG 0
@@ -486,13 +486,36 @@ struct GridwiseMoeGemmGather
             make_tuple(Sequence<0>{}, Sequence<1>{}));
     }
+    template <typename DLayout>
+    __host__ __device__ static auto
+    MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
+    {
+        const auto c_grid_desc_mraw_nraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, DLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
+            }
+        }();
+        // pad M and N
+        return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
+                                           make_tuple(make_right_pad_transform(M, MPad - M),
+                                                      make_right_pad_transform(N, NPad - N)),
+                                           make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
+    }
+
     __host__ __device__ static auto MakeDsGridDescriptor_M_N(
         index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
     {
         return generate_tuple(
             [&](auto i) {
                 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-                return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
+                return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
             },
             Number<NumDTensor>{});
     }
@@ -509,8 +532,6 @@ struct GridwiseMoeGemmGather
             Number<NumDTensor>{});
     }
-    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
-
     struct Problem
     {
         __host__ __device__ Problem(index_t NumTokens_,
@@ -1158,10 +1179,6 @@ struct GridwiseMoeGemmGather
         // if(threadIdx.x==0)
         //     printf("tid %d eid %d expert_stride %d bufsize %d\n",
        //            threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -1377,8 +1394,18 @@ struct GridwiseMoeGemmGather
         const auto ds_grid_buf = generate_tuple(
             [&](auto i) {
+                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+                const DDataType* ptr_ = p_ds_grid[i];
+                // hack logic here to support different kind of strides. todo fix it.
+                // ascale t, 1; bscale E, N, 1, move ptr to E
+                if(i.value == 1)
+                {
+                    ptr_ += expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1);
+                    // if(threadIdx.x == 0)
+                    //     printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
+                }
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
+                    ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
             },
             Number<NumDTensor>{});
@@ -1411,11 +1438,23 @@ struct GridwiseMoeGemmGather
         const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
             c_grid_desc_mblock_mperblock_nblock_nperblock;
-        using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
+        using CDEBlockTransferCluster =
             CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
         const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
+        constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
+        constexpr auto EMRepeats = MPerBlock / EMThreads;
+        constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
+        const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
+        StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
+        StaticallyIndexedArray<float, EMRepeats> scatter_weights;  //= for topk
+        // too hack here, 2 specific for topk weights, fixme
+        const float* p_sorted_weights = p_ds_grid[I2];
+        static_for<0, EMRepeats, 1>{}([&](auto m0) {
+            scatter_offsets(m0) = 0;
+            scatter_weights(m0) = p_sorted_weights[c_token_pos + m0];
+            // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
+        });
-        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
+        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
             ThisThreadBlock,
             decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
             Tuple<EDataType>,
@@ -1428,7 +1467,7 @@ struct GridwiseMoeGemmGather
                      CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                      1,
                      CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            CDEBlockTransferCluster,
             Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
             Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
             Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
@@ -1440,13 +1479,21 @@ struct GridwiseMoeGemmGather
             Sequence<true>,
                   uniform_sequence_gen_t<NumDTensor,
                                          false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
+            Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
+            1,     // ScatterDim
+            false, // OutputScatter: false, only use scatter weights
+            1      // ScatterWeightIdx: ascale
+            >
             {c_ds_desc_refs,
              idx_c_ds_block_begin,
             tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
             make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
-             c_element_op};
+             c_element_op,
+             scatter_offsets,
+             scatter_weights};
+        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
         // space filling curve for threadwise C in VGPR
         constexpr auto sfc_c_vgpr =
             SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
@@ -1472,7 +1519,6 @@ struct GridwiseMoeGemmGather
                               CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
         static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
-        // printf("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee\n");
         static_for<0, num_access, 1>{}([&](auto access_id) {
             // make sure it's safe to write to LDS
             block_sync_lds();
...
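
Note: the ds_grid_buf hunk above rebases the D1 pointer per expert. A standalone sketch of that offset rule, where d1_expert_base is a hypothetical helper mirroring the "move ptr to E" hack: a non-zero StrideDs[1] treats D1 as a full per-expert {E, N} tensor with expert rows StrideDs[1] * N apart, while StrideDs[1] == 0 degenerates to one scale value per expert.

using index_t = int; // stand-in for ck::index_t

const float* d1_expert_base(const float* p_d1, index_t expert_id,
                            index_t stride_d1, index_t n)
{
    // expert_id * (StrideDs[1] ? StrideDs[1] * N : 1), as in the diff above
    return p_d1 + expert_id * (stride_d1 ? stride_d1 * n : 1);
}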
@@ -532,7 +532,6 @@ struct GridwiseMoeGemmScatter
             Number<NumDTensor>{});
     }
-
     struct Problem
     {
         __host__ __device__ Problem(index_t NumTokens_,
@@ -1427,15 +1426,14 @@ struct GridwiseMoeGemmScatter
         constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
         constexpr auto EMRepeats = MPerBlock / EMThreads;
         constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
-        // static_assert(EMRepeats == 1, "only support 1 line per thread now!");
-        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
-        StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[token_pos];
+        const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
+        StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
         StaticallyIndexedArray<float, EMRepeats> scatter_weights;  //= for topk
         // too hack here, 2 specific for topk weights, fixme
         const float* p_sorted_weights = p_ds_grid[I2];
         static_for<0, EMRepeats, 1>{}([&](auto m0) {
-            scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N;
-            scatter_weights(m0) = p_sorted_weights[token_pos + m0];
+            scatter_offsets(m0) = (p_sorted_token_ids[c_token_pos + m0] & 0xffffff) * problem.N;
+            scatter_weights(m0) = p_sorted_weights[c_token_pos + m0];
             // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
         });
...
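
Note: a host-side sketch of the per-thread scatter table built in both kernels above; fill_scatter_table is a hypothetical helper. It assumes the low 24 bits of each sorted token id hold the destination row (hence the & 0xffffff mask) and that output rows are N elements apart. The gather variant pins the offsets to 0 instead, because it instantiates the transfer with OutputScatter = false and only consumes the weights.

#include <cstdint>

void fill_scatter_table(int32_t* offsets, float* weights,
                        const int32_t* sorted_token_ids,
                        const float* sorted_weights,
                        int c_token_pos, int em_repeats, int n)
{
    for(int m0 = 0; m0 < em_repeats; ++m0)
    {
        // low 24 bits = destination row; * n converts rows to element offsets
        offsets[m0] = (sorted_token_ids[c_token_pos + m0] & 0xffffff) * n;
        weights[m0] = sorted_weights[c_token_pos + m0];
    }
}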
@@ -44,6 +44,8 @@ template <typename SrcDatas,
           typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
           typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
           index_t ScatterDim = 1,
+          bool OutputScatter = true,
+          index_t ScatterWeightIdx = 3,
           index_t NumThreadScratch = 1>
 struct ThreadwiseTensorSliceTransfer_v7r3_scatter
 {
@@ -174,7 +176,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                     src_coords_[i]);
                 oob_val = oob_val & is_src_valid;
-                if(i.value == 3)
+                if(i.value == ScatterWeightIdx)
                 {
                     static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec");
                     constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
@@ -187,8 +189,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                     using DataType = remove_cvref_t<decltype(data_types[i])>;
                     const auto tmp =
                         src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
-                    // if(i.value == 2)
-                    //     printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
                     static_for<0, SrcScalarPerVector, 1>{}(
                         [&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
                 }
@@ -420,8 +420,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
         // loop over space-filling curve
         static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
             auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
-            constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
-            const auto scatter_offset = scatter_offsets_(Number<iScatter>{});
+            auto scatter_offset = 0;
+            if constexpr(OutputScatter)
+            {
+                constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
+                scatter_offset = scatter_offsets_(Number<iScatter>{});
+            }
             // copy data from buf_vectors into dst_bufs
             static_for<0, nDst, 1>{}([&](auto i) {
                 using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
@@ -459,7 +463,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
             Index step_;
             static_for<0, nDim, 1>{}([&](auto i) {
-                step_(i) = i.value != ScatterDim ? forward_step[i] : 0;
+                step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
                 // if(threadIdx.x==0)
                 //     printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
@@ -530,7 +534,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
         {
             Index step_;
             static_for<0, nDim, 1>{}([&](auto i) {
-                step_(i) = i.value != ScatterDim ? reset_step[Number<i>{}] : 0;
+                step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
                 // if(threadIdx.x==0)
                 //     printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
...
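
Note: the two step hunks above apply the same rule in the forward and reset paths. A minimal sketch of that rule, where mask_scatter_step is a hypothetical helper: when the transfer scatters, the scatter dim's coordinate never advances, since its position comes from scatter_offsets_ rather than the space-filling walk.

using index_t = int; // stand-in for ck::index_t

index_t mask_scatter_step(index_t dim, index_t step,
                          index_t scatter_dim, bool output_scatter)
{
    // (i.value == ScatterDim && OutputScatter) ? 0 : step, as in the diff
    return (dim == scatter_dim && output_scatter) ? 0 : step;
}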
@@ -49,8 +49,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
         {
         }
-        const Tensor<ck::index_t>& expert_ids_;
         const Tensor<ck::index_t>& sorted_token_ids_;
+        const Tensor<ck::index_t>& expert_ids_;
+        index_t sorted_tile_size_;
         const Tensor<ADataType>& a_t_k_;
         const Tensor<BDataType>& b_e_n_k_;
         Tensor<CDataType>& c_m_n_;
@@ -58,7 +59,6 @@ struct ReferenceMoeGemm : public device::BaseOperator
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        index_t sorted_tile_size_;
     };
     // Invoker
...
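
Note: the Argument hunk above only moves sorted_tile_size_ up next to the id tensors. For context, a host-side sketch of how the example's verification loop (the main() hunk earlier) consumes these members; reference_cde is a hypothetical standalone version of that loop. Each sorted row m is mapped back to its source token t before the fused element op is applied.

// Generic over the tensor and functor types so it stays self-contained.
template <typename Op, typename TE, typename TC, typename TD0, typename TD1,
          typename TD2, typename TIds>
void reference_cde(Op cde_element_op, TE& e, const TC& c, const TD0& d0,
                   const TD1& d1, const TD2& d2, const TIds& sorted_token_ids,
                   int sorted_size, int n_cols)
{
    for(int m = 0; m < sorted_size; ++m)
    {
        const int t = sorted_token_ids(m); // sorted row -> source token
        for(int n = 0; n < n_cols; ++n)
            cde_element_op(e(m, n), c(m, n), d0(t, n), d1(m, n), d2(m, n));
    }
}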