Commit a8a82e0c authored by coderfeli

fix warnings and impl scale for gemm2, build ok

parent 69f54ee8
@@ -152,6 +152,9 @@ static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this const
 static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
 static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
 static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
+static constexpr ck::index_t D0Vec = 1;
+static constexpr ck::index_t D1Vec = 1;
+static constexpr ck::index_t D2Vec = 1;
 // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
 // clang-format off
@@ -181,7 +184,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
 // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-1, 1, S<1, 32, 1, 8>, S<EVec, EVec, 1, EVec>,
+1, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
 ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
 // kernel 2: 128->32x128x128
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......
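A note on the last changed line in the hunk above (my reading of the template arguments; the commit itself does not explain it): the S<...> tuple next to CBlockTransferClusterLengths appears to hold one ScalarPerVector width per epilogue tensor, in the order E, D0, D1, D2, so the broadcast scale and expert-weight inputs now get an explicit width of 1 instead of reusing EVec. A tiny standalone check of the widths involved, assuming a 2-byte EDataType as in the example:

```cpp
// Sketch only, not part of the commit. EVec is the number of E elements per
// 16-byte access; the D0/D1/D2 scale inputs are broadcast, so width 1.
#include <cstdint>
#include <cstdio>

using EDataType = std::uint16_t; // stand-in for the example's 16-bit output type

int main()
{
    constexpr int EVec  = 16 / sizeof(EDataType); // = 8
    constexpr int D0Vec = 1;                      // a-scale
    constexpr int D1Vec = 1;                      // b-scale
    constexpr int D2Vec = 1;                      // expert weight
    std::printf("S<EVec, D0Vec, D1Vec, D2Vec> = S<%d, %d, %d, %d>\n",
                EVec, D0Vec, D1Vec, D2Vec);
    return 0;
}
```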
@@ -40,9 +40,8 @@ using AccDataType = F32;
 using CShuffleDataType = F32;
 using D0DataType = F32;
 using D1DataType = F32;
-using D2DataType = EDataType;
-// using DsDataTypeGate = ck::Tuple<D0DataType, D1DataType>;
-using DsDataTypeUp = ck::Tuple<D0DataType, D1DataType, D2DataType>;
+using D2DataType = F32;
+using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
 using A0Layout = Row;
 using B0Layout = Col;
@@ -51,35 +50,39 @@ using D0Layout = Row;
 using D1Layout = Col;
 using D2Layout = ELayout;
 // using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
-using DsLayoutUp = ck::Tuple<D0Layout, D1Layout, D2Layout>;
+using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
-struct MultiplyMultiply
+// d0: ascale, d1: bscale, d2: expert weight
+struct MulABScaleExpertWeight
 {
 template <typename E, typename C, typename D0, typename D1, typename D2>
 __host__ __device__ constexpr void
 operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
+//gpu
 template <>
-__host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>
+__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>
 (EDataType& e,
 const float& c,
 const float& d0,
 const float& d1,
-const D2DataType& d2) const
+const float& d2) const
 {
-// const float x0_f = c * d0 * d1;
-(void)d0; (void)d1; (void)d2;
-const float x0_f = c;
-e = ck::type_convert<EDataType>(x0_f);
+e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
+}
+// for reference
+template <>
+__host__ __device__ constexpr void operator()<float, float, float, float, float>
+(float& e,
+const float& c,
+const float& d0,
+const float& d1,
+const float& d2) const
+{
+e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
 }
 };
-// using DsLayout = DsLayoutGate;
-// using DsDataType = DsDataTypeGate;
-using DsLayout = DsLayoutUp;
-using DsDataType = DsDataTypeUp;
-using CDEElementOp = MultiplyMultiply;
+using CDEElementOp = MulABScaleExpertWeight;
 void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
 {
@@ -115,7 +118,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
-using CDEElementOp = MultiplyMultiply;
+using CDEElementOp = MulABScaleExpertWeight;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr ck::index_t MPerBlock = 32;
@@ -126,6 +129,9 @@ static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
 static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
 static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
 static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
+static constexpr ck::index_t D0Vec = 1;
+static constexpr ck::index_t D1Vec = 1;
+static constexpr ck::index_t D2Vec = 1;
 // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
 // clang-format off
@@ -155,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
 // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1, EVec>,
+CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, D0Vec, D1Vec, D2Vec>,
 ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>;
 // kernel 2: 128->32x128x128
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -232,16 +238,16 @@ int main(int argc, char* argv[])
 Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
 Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
 Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
-Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
-Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
+Tensor<D0DataType> d0_t_n(HostTensorDescriptor({SORTED_SIZE, N}, {0, 0}));
+Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {0, 0}));
+Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, 1}, {1, 0}));
 Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
 Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
 e_t_n_device_result.SetZero();
 std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
 std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
-std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
-std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
+std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
+std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
 std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
 std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
@@ -252,38 +258,38 @@ int main(int argc, char* argv[])
 a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
 b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
 d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
-d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
-d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
+d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
+d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
 break;
 case 2:
 a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
 b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
 d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
-d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
-d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
+d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
+d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
 break;
 default:
 a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
 b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
 d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
-d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
-d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
+d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
+d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
 }
 DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
 DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
 DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
 DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
 DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
-DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
-DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
+DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
+DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
 DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
 a0_m_k.savetxt("a.txt");
 sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
 expert_ids_dev.ToDevice(expert_ids.mData.data());
 a0_device_buf.ToDevice(a0_m_k.mData.data());
 d0_device_buf.ToDevice(d0_t_n.mData.data());
-d1_device_buf.ToDevice(d1_m_n.mData.data());
-d2_device_buf.ToDevice(d2_m_n.mData.data());
+d1_device_buf.ToDevice(d1_e_n.mData.data());
+d2_device_buf.ToDevice(d2_e_n.mData.data());
 e_device_buf.ToDevice(e_t_n_device_result.mData.data());
 auto a_element_op = AElementOp{};
@@ -358,26 +364,26 @@ int main(int argc, char* argv[])
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
 B0DataType,
+D0DataType,
+D1DataType,
+D2DataType,
 CShuffleDataType,
 AccDataType,
 PassThrough,
 PassThrough,
-PassThrough>;
+CDEElementOp>;
 auto ref_moe_gemm = ReferenceGemmInstance{};
 auto ref_invoker = ref_moe_gemm.MakeInvoker();
 auto ref_argument = ref_moe_gemm.MakeArgument(
-sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{});
+sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
 ref_invoker.Run(ref_argument);
 for(int t = 0; t < tokens; ++t)
 {
-// const int t = sorted_token_ids(m);
 for(int n = 0; n < N; ++n)
 {
-cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_m_n(t, n), d2_m_n(t, n));
+e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
 }
 }
......
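For orientation, a sketch that is not part of the commit: with MulABScaleExpertWeight in place, every output element of gemm2 is scaled by the a-scale (d0), the b-scale (d1) and the per-expert routing weight (d2) before conversion, which is why the host check at the end of main now only converts the reference accumulator. A self-contained, float-only version of that epilogue:

```cpp
// Minimal sketch of the new epilogue, assuming plain float in place of the
// example's EDataType so it compiles and runs on its own.
#include <cassert>

float scaled_epilogue(float c_acc, float a_scale, float b_scale, float expert_w)
{
    // ck::type_convert<EDataType>(c * d0 * d1 * d2) in the real operator
    return c_acc * a_scale * b_scale * expert_w;
}

int main()
{
    assert(scaled_epilogue(2.0f, 0.5f, 4.0f, 0.25f) == 1.0f); // exact: powers of two
    return 0;
}
```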
@@ -535,7 +535,7 @@ struct DeviceMoeGemm
 index_t KBatch,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
-CElementwiseOperation c_element_op)
+CElementwiseOperation c_element_op) override
 {
 // assert(0, "no impl");
 return std::make_unique<Argument>(nullptr, nullptr,
......
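The one-word change above adds `override` to a virtual argument factory in DeviceMoeGemm, presumably to satisfy an override-consistency warning (the commit message only says "fix warnings"). An illustrative, self-contained example of what `override` buys:

```cpp
// Illustrative only: `override` makes the compiler verify the signature against
// the base class and silences -Wsuggest-override / -Winconsistent-missing-override
// style diagnostics.
#include <memory>

struct Base
{
    virtual ~Base() = default;
    virtual std::unique_ptr<int> Make(int v) = 0;
};

struct Derived : Base
{
    std::unique_ptr<int> Make(int v) override { return std::make_unique<int>(v); }
};
```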
@@ -901,9 +901,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
 SrcCoord src_coord_;
 DstCoord dst_coord_;
-StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
 const SrcElementwiseOperation src_element_op_;
 const DstElementwiseOperation dst_element_op_;
+StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
 };
 } // namespace ck
@@ -687,10 +687,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
 using OOBVectorTuple = StaticallyIndexedArray<bool, src_num_access>;
 StaticallyIndexedArray<OOBVectorTuple, NumThreadScratch> oob_vectors_tuple_;
-StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
 SrcCoords src_coords_;
 DstCoords dst_coords_;
 const ElementwiseOperation element_op_;
+StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
 };
 } // namespace ck
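Both hunks above only move the gather_offsets_ / scatter_offsets_ members to the end of their structs; the likely motive (an assumption on my part, consistent with "fix warnings") is that members are initialized in declaration order, so a declaration order that disagrees with the constructor's initializer list triggers -Wreorder. A minimal reproduction:

```cpp
// Illustrative only: declaration order y_, x_ vs. initializer-list order x_, y_
// makes GCC/Clang emit -Wreorder; reordering the declarations (as the commit
// does for the offset arrays) silences it without changing behavior.
struct Example
{
    Example(int a, int b) : x_{a}, y_{b} {} // initializer list names x_ first
    int y_; // but y_ is declared, and therefore initialized, first
    int x_;
};
```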
@@ -17,6 +17,9 @@ namespace host {
 template <typename ADataType,
 typename BDataType,
 typename CDataType,
+typename D0DataType,
+typename D1DataType,
+typename D2DataType,
 typename AccDataType,
 typename AElementwiseOperation,
 typename BElementwiseOperation,
@@ -33,6 +36,9 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
 const index_t sorted_tile_size,
 const Tensor<ADataType>& a_m_k,
 const Tensor<BDataType>& b_e_n_k,
+const Tensor<D0DataType>& d0,
+const Tensor<D1DataType>& d1,
+const Tensor<D2DataType>& d2,
 Tensor<CDataType>& c_t_n,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
@@ -42,6 +48,9 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
 sorted_tile_size_{sorted_tile_size},
 a_m_k_{a_m_k},
 b_e_n_k_{b_e_n_k},
+d0_{d0},
+d1_{d1},
+d2_{d2},
 c_t_n_{c_t_n},
 a_element_op_{a_element_op},
 b_element_op_{b_element_op},
@@ -49,16 +58,19 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
 {
 }
-const Tensor<ck::index_t>& expert_ids_;
 const Tensor<ck::index_t>& sorted_token_ids_;
+const Tensor<ck::index_t>& expert_ids_;
+index_t sorted_tile_size_;
 const Tensor<ADataType>& a_m_k_;
 const Tensor<BDataType>& b_e_n_k_;
+const Tensor<D0DataType>& d0_;
+const Tensor<D1DataType>& d1_;
+const Tensor<D2DataType>& d2_;
 Tensor<CDataType>& c_t_n_;
 AElementwiseOperation a_element_op_;
 BElementwiseOperation b_element_op_;
 CElementwiseOperation c_element_op_;
-index_t sorted_tile_size_;
 };
 // Invoker
@@ -106,8 +118,10 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
 ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
 }
 CDataType v_c{0};
-arg.c_element_op_(v_c, v_acc);
+D0DataType v_d0 = arg.d0_(m, n); // a
+D0DataType v_d1 = arg.d1_(e, n); // b
+D0DataType v_d2 = arg.d2_(e, 0); // expert
+arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_d2);
 arg.c_t_n_(t, n) += v_c;
 }
@@ -140,12 +154,15 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
 const index_t sorted_tile_size,
 const Tensor<ADataType>& a_m_k,
 const Tensor<BDataType>& b_e_n_k,
+const Tensor<D0DataType>& d0,
+const Tensor<D1DataType>& d1,
+const Tensor<D2DataType>& d2,
 Tensor<CDataType>& c_t_n,
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
 CElementwiseOperation c_element_op)
 {
-return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, c_t_n, a_element_op, b_element_op, c_element_op};
+return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op};
 }
 static auto MakeInvoker() { return Invoker{}; }
......
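Putting the ReferenceMoeGemm2 changes together (my summary of the hunks above; m is a sorted slot, t = sorted_token_ids[m], and e is the expert that owns m's tile), the reference now computes

$$
C(t,n) \mathrel{+}= \operatorname{convert}\!\Big(\Big(\sum_{k} A(m,k)\,B(e,n,k)\Big)\cdot d_0(m,n)\cdot d_1(e,n)\cdot d_2(e)\Big),
\qquad
E(t,n) = \operatorname{convert}\big(C(t,n)\big),
$$

so the device result is compared against a host path that applies the a-scale, b-scale and expert weight per sorted slot before accumulating into the token row.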