Commit b02c0b82 authored by coderfeli

gemm1 scale debug

parent e4ca61f9
......@@ -35,7 +35,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using B0DataType = F8;
using EDataType = F16;
using EDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
......@@ -68,9 +68,11 @@ struct MulABScale
const float& d1,
const D2DataType& d2) const
{
// const float x0_f = c * d0 * d1;
(void)d0; (void)d1; (void)d2;
(void)d2; // for gate, no d2 needed
(void)d0;
(void)d1;
const float x0_f = c;
// const float x0_f = c;
e = ck::type_convert<EDataType>(x0_f);
}
};
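Editor's note: MulABScale is the epilogue of the gate GEMM; the commented-out line shows the intended scaled form e = c * d0 * d1, while this debug change writes the raw accumulator through. A minimal stand-alone sketch of both forms (plain C++, with float standing in for the CK data types):

```cpp
// Sketch only: compares the intended scaled epilogue with the debug pass-through.
#include <cstdio>

// Hypothetical stand-in for the full CDE element-wise op: e = c * d0 * d1.
float mul_ab_scale(float c, float d0, float d1) { return c * d0 * d1; }

// Debug variant from this commit: the scales are ignored, e = c.
float mul_ab_scale_debug(float c, float /*d0*/, float /*d1*/) { return c; }

int main()
{
    const float c = 3.0f, d0 = 0.5f, d1 = 2.0f;
    std::printf("scaled: %f, debug: %f\n",
                mul_ab_scale(c, d0, d1), mul_ab_scale_debug(c, d0, d1));
    return 0;
}
```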
......@@ -91,8 +93,10 @@ struct MulABScaleSiluMulGate
const D2DataType& d2) const
{
// act
(void)d0;
(void)d1;
(void)d2;
float x0;
float x0 = 0;
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
// fuse mul
e = ck::type_convert<EDataType>(x0);
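Editor's note: SiLU(x) = x * sigmoid(x). The struct name suggests the full epilogue multiplies the activation by a gate value, presumably d2; the path shown here only applies SiLU. A minimal sketch under that assumption (plain C++):

```cpp
// Sketch only. Assumption: the intended fused form is e = SiLU(c * d0 * d1) * d2,
// where d2 is the gate value; only the SiLU part is active in this debug commit.
#include <cmath>
#include <cstdio>

float silu(float x) { return x / (1.0f + std::exp(-x)); } // SiLU(x) = x * sigmoid(x)

int main()
{
    const float c = 1.5f, d0 = 0.5f, d1 = 2.0f, d2 = 0.25f; // d2: hypothetical gate value
    const float act   = silu(c * d0 * d1); // what the shown code computes
    const float gated = act * d2;          // the multiply the "fuse mul" comment refers to
    std::printf("act = %f, gated = %f\n", act, gated);
    return 0;
}
```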
......@@ -145,7 +149,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; // TODO: fix this constraint
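Editor's note: with the tile shrunk from MPerBlock = 128 to 32 and a 32-wide XDL, the derived values work out as below (a compile-time sketch in plain C++, with std::int8_t standing in for the 1-byte F8 type):

```cpp
// Quick compile-time check of the tuning values after this change (sketch only).
#include <cstdint>

static constexpr int MPerBlock   = 32;
static constexpr int MNPerXDL    = 32;
static constexpr int KPerBlock   = 256 / sizeof(std::int8_t); // = 256 for a 1-byte element
static constexpr int MXDLPerWave = MPerBlock / 32;            // = 1 with the new MPerBlock

static_assert(KPerBlock == 256, "KPerBlock with a 1-byte F8 type should be 256");
static_assert(MXDLPerWave == 1, "MPerBlock = 32 gives one XDL tile per wave in M");

int main() { return 0; }
```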
......@@ -208,7 +212,7 @@ int main(int argc, char* argv[])
ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 512;
ck::index_t tokens = 32;
if(argc == 1)
{
......@@ -236,6 +240,8 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
// ck::index_t StrideD = 0;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
......@@ -261,15 +267,15 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
......@@ -281,21 +287,21 @@ int main(int argc, char* argv[])
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
......@@ -303,7 +309,7 @@ int main(int argc, char* argv[])
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt");
......@@ -311,7 +317,7 @@ int main(int argc, char* argv[])
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
......@@ -319,8 +325,6 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto I0 = ck::Number<0>{};
// do GEMM
......@@ -404,7 +408,7 @@ int main(int argc, char* argv[])
const int t = sorted_token_ids(m);
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_m_n(m, n), d2_m_n(m, n));
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(m, n), d2_m_n(m, n));
}
}
......
......@@ -43,13 +43,15 @@ template <typename ThreadGroup,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t ScatterDim = 1,
bool OutputScatter = true,
index_t ScatterWeightIdx = 3,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r3_scatter
{
static constexpr index_t nDim =
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}); // Dirty HACK FELIX, TODO fix
static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}) ; // Dirty HACK FELIX, TODO fix
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
......@@ -114,7 +116,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
Number<nSrc>{});
const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() % mod_num));
make_multi_index( OutputScatter ? ThreadGroup::GetThreadId() % mod_num : ThreadGroup::GetThreadId()));
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; },
Number<nDst>{});
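Editor's note: the new OutputScatter flag changes how the destination thread-cluster index is derived. With scatter output the thread id is folded modulo mod_num (assumed here to be the cluster length along the last, N, dimension), so the destination coordinate only varies along N and the M position comes from the per-thread scatter offsets; otherwise the full thread id is used. A small sketch of the two mappings:

```cpp
// Sketch of the thread-id remapping introduced above (plain C++).
#include <cstdio>

int dst_cluster_id(int thread_id, int mod_num, bool output_scatter)
{
    return output_scatter ? thread_id % mod_num : thread_id;
}

int main()
{
    const int mod_num = 8; // hypothetical N-dimension cluster length
    for(int tid : {0, 7, 8, 15})
        std::printf("tid %2d -> scatter %d, dense %d\n",
                    tid, dst_cluster_id(tid, mod_num, true), dst_cluster_id(tid, mod_num, false));
    return 0;
}
```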
......@@ -219,6 +221,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags,
ScatterDim,
OutputScatter,
ScatterWeightIdx,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
......
......@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
#define DEBUG_LOG 0
......@@ -486,13 +486,36 @@ struct GridwiseMoeGemmGather
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
template <typename DLayout>
__host__ __device__ static auto
MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
}
}();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
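Editor's note: MakeDGridDescriptor_M_N builds a 2-D naive descriptor with one stride forced to zero (row-major: strides (StrideC, 0); column-major: (0, StrideC)) and then right-pads both dimensions up to MPad/NPad. A worked example of the pad amounts, assuming MPad and NPad are M and N rounded up to the block tile sizes:

```cpp
// Worked example of the right-pad amounts (plain C++ arithmetic only).
#include <cstdio>

int round_up(int x, int tile) { return (x + tile - 1) / tile * tile; }

int main()
{
    const int M = 30, N = 60;
    const int MPerBlock = 32, NPerBlock = 64; // hypothetical tile sizes
    const int MPad = round_up(M, MPerBlock);  // 32 -> right pad of 2 rows
    const int NPad = round_up(N, NPerBlock);  // 64 -> right pad of 4 columns
    std::printf("pad M by %d, pad N by %d\n", MPad - M, NPad - N);
    return 0;
}
```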
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
},
Number<NumDTensor>{});
}
......@@ -509,8 +532,6 @@ struct GridwiseMoeGemmGather
Number<NumDTensor>{});
}
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
struct Problem
{
__host__ __device__ Problem(index_t NumTokens_,
......@@ -1159,10 +1180,6 @@ struct GridwiseMoeGemmGather
// printf("tid %d eid %d expert_stride %d bufsize %d\n",
// threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
......@@ -1377,8 +1394,18 @@ struct GridwiseMoeGemmGather
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType *ptr_ = p_ds_grid[i];
// hack logic here to support different kinds of strides; TODO: fix it.
// a-scale is (tokens, 1); b-scale is (E, N, 1): move the pointer to the current expert E
if (i.value == 1)
{
ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
// if ( threadIdx.x ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
}
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
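Editor's note: for the b-scale tensor (D index 1), the base pointer is advanced to the current expert before the dynamic buffer is created. The sketch below only mirrors the commit's offset expression; how StrideDs[1] is meant to interact with N is part of the acknowledged hack:

```cpp
// Sketch of the per-expert pointer offset added above (plain C++).
// Assumption: StrideDs[1] == 0 means one scale value per expert, so the offset is
// just expert_id; otherwise the commit's expression expert_id * StrideDs[1] * N is used.
#include <cstdio>

long expert_offset(int expert_id, long stride_d1, long N)
{
    // Matches: expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1)
    return expert_id * (stride_d1 ? stride_d1 * N : 1);
}

int main()
{
    std::printf("scalar-per-expert: e=3 -> offset %ld\n", expert_offset(3, 0, 4096));
    std::printf("strided          : e=3 -> offset %ld\n", expert_offset(3, 1, 4096));
    return 0;
}
```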
......@@ -1411,11 +1438,23 @@ struct GridwiseMoeGemmGather
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
c_grid_desc_mblock_mperblock_nblock_nperblock;
using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = MPerBlock / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hacky here: index 2 is specific to the topk weights, FIXME
const float *p_sorted_weights = p_ds_grid[I2];
static_for<0, EMRepeats, 1>{}([&](auto m0) {
scatter_offsets(m0) = 0;
scatter_weights(m0) = p_sorted_weights[c_token_pos + m0];
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
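Editor's note: c_token_pos maps each thread to the first sorted-token row it handles inside this block tile. A numeric sketch with hypothetical cluster lengths (the real values come from the tuning parameters):

```cpp
// Sketch of the thread-to-row mapping used above (plain C++).
// Assumption: a 256-thread block with CDE cluster lengths {1, 32, 1, 8} and MPerBlock = 32.
#include <cstdio>

int main()
{
    const int MPerBlock = 32;
    const int EMThreads = 1 * 32;                // cluster lengths at I0 * I1
    const int ENThreads = 1 * 8;                 // cluster lengths at I2 * I3
    const int EMRepeats = MPerBlock / EMThreads; // rows handled per thread group (= 1 here)
    const int block_m_id = 2;

    for(int tid : {0, 7, 8, 255})
    {
        const int c_token_pos = block_m_id * MPerBlock + tid / ENThreads * EMRepeats;
        std::printf("tid %3d -> first sorted-token row %d (handles %d row(s))\n",
                    tid, c_token_pos, EMRepeats);
    }
    return 0;
}
```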
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
......@@ -1428,7 +1467,7 @@ struct GridwiseMoeGemmGather
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferCluster,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
......@@ -1440,13 +1479,21 @@ struct GridwiseMoeGemmGather
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
1, //ScatterDim
false, //OutputScatter: false, only use scatter weights
1 // ScatterWeightIdx: ascale
>
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
c_element_op};
c_element_op,
scatter_offsets,
scatter_weights};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -1472,7 +1519,6 @@ struct GridwiseMoeGemmGather
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
// printf("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee\n");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
......
......@@ -532,7 +532,6 @@ struct GridwiseMoeGemmScatter
Number<NumDTensor>{});
}
struct Problem
{
__host__ __device__ Problem(index_t NumTokens_,
......@@ -1427,15 +1426,14 @@ struct GridwiseMoeGemmScatter
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = MPerBlock / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
// static_assert(EMRepeats == 1, "only support 1 line per thread now!");
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[token_pos];
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hacky here: index 2 is specific to the topk weights, FIXME
const float *p_sorted_weights = p_ds_grid[I2];
static_for<0, EMRepeats, 1>{}([&](auto m0) {
scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N;
scatter_weights(m0) = p_sorted_weights[token_pos + m0];
scatter_offsets(m0) = (p_sorted_token_ids[c_token_pos + m0] & 0xffffff) * problem.N;
scatter_weights(m0) = p_sorted_weights[c_token_pos + m0];
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
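Editor's note: the sorted token ids appear to be packed, with the destination row in the low 24 bits; the mask strips the upper bits before the row is turned into an element offset (row * N). A minimal decode sketch under that assumption:

```cpp
// Sketch of the packed sorted-token-id decode used above (plain C++).
// Assumption: low 24 bits hold the destination token row; the upper 8 bits carry
// auxiliary information that is masked off here.
#include <cstdio>

int main()
{
    const unsigned packed = (5u << 24) | 1234u; // hypothetical: upper byte 5, token row 1234
    const unsigned N = 4096;
    const unsigned token_row = packed & 0xffffffu;
    const unsigned scatter_offset = token_row * N;
    std::printf("token row %u -> scatter offset %u\n", token_row, scatter_offset);
    return 0;
}
```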
......
......@@ -44,6 +44,8 @@ template <typename SrcDatas,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
index_t ScatterDim = 1,
bool OutputScatter = true,
index_t ScatterWeightIdx = 3,
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v7r3_scatter
{
......@@ -174,7 +176,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
src_coords_[i]);
oob_val = oob_val & is_src_valid;
if (i.value == 3)
if (i.value == ScatterWeightIdx)
{
static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim should use only one vector");
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
......@@ -187,8 +189,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using DataType = remove_cvref_t<decltype(data_types[i])>;
const auto tmp =
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
// if(i.value == 2)
// printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
}
......@@ -420,8 +420,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// loop over space-filling curve
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
auto scatter_offset = 0;
if constexpr (OutputScatter)
{
constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
const auto scatter_offset = scatter_offsets_(Number<iScatter>{});
scatter_offset = scatter_offsets_(Number<iScatter>{});
}
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
......@@ -459,7 +463,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
Index step_;
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = i.value != ScatterDim ? forward_step[i] : 0;
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
......@@ -530,7 +534,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
{
Index step_;
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = i.value != ScatterDim ? reset_step[Number<i>{}] : 0;
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
......
......@@ -49,8 +49,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
{
}
const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& sorted_token_ids_;
const Tensor<ck::index_t>& expert_ids_;
index_t sorted_tile_size_;
const Tensor<ADataType>& a_t_k_;
const Tensor<BDataType>& b_e_n_k_;
Tensor<CDataType>& c_m_n_;
......@@ -58,7 +59,6 @@ struct ReferenceMoeGemm : public device::BaseOperator
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t sorted_tile_size_;
};
// Invoker
......