Commit a8a82e0c authored by coderfeli's avatar coderfeli
Browse files

Fix compiler warnings and implement scaling (A-scale, B-scale, expert weight) for gemm2; build passes.

parent 69f54ee8
......@@ -152,6 +152,9 @@ static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this const
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
......@@ -181,7 +184,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
1, 1, S<1, 32, 1, 8>, S<EVec, EVec, 1, EVec>,
1, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......
......@@ -40,9 +40,8 @@ using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = EDataType;
// using DsDataTypeGate = ck::Tuple<D0DataType, D1DataType>;
using DsDataTypeUp = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
......@@ -51,35 +50,39 @@ using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
using DsLayoutUp = ck::Tuple<D0Layout, D1Layout, D2Layout>;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// Element-wise epilogue functor for the MoE GEMM:
//   d0: A-side scale ("ascale"), d1: B-side scale ("bscale"),
//   d2: per-expert routing weight.
// Computes e = convert<EDataType>(c * d0 * d1 * d2).
// (The stale pre-rename lines from the diff pre-image — the old
// `MultiplyMultiply` name and the `D2DataType` specialization — have been
// removed so this span is valid C++ again.)
struct MulABScaleExpertWeight
{
    // Primary template: declared only; all supported type combinations are
    // provided as explicit specializations below.
    template <typename E, typename C, typename D0, typename D1, typename D2>
    __host__ __device__ constexpr void
    operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;

    // GPU path: apply both scales and the expert weight to the fp32
    // accumulator, then convert to the (possibly narrower) output type.
    template <>
    __host__ __device__ constexpr void operator()<EDataType, float, float, float, float>
        (EDataType& e,
         const float& c,
         const float& d0,
         const float& d1,
         const float& d2) const
    {
        e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
    }

    // Host-reference path (E is float).
    template <>
    __host__ __device__ constexpr void operator()<float, float, float, float, float>
        (float& e,
         const float& c,
         const float& d0,
         const float& d1,
         const float& d2) const
    {
        // NOTE(review): the result is round-tripped through EDataType even
        // though e is float — presumably so the host reference reproduces the
        // device's output rounding. Confirm this is intended; otherwise the
        // conversion should target float.
        e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
    }
};
// using DsLayout = DsLayoutGate;
// using DsDataType = DsDataTypeGate;
using DsLayout = DsLayoutUp;
using DsDataType = DsDataTypeUp;
using CDEElementOp = MultiplyMultiply;
using CDEElementOp = MulABScaleExpertWeight;
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
......@@ -115,7 +118,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 32;
......@@ -126,6 +129,9 @@ static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
......@@ -155,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1, EVec>,
CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......@@ -232,16 +238,16 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({SORTED_SIZE, N}, {0, 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {0, 0}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, 1}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
......@@ -252,38 +258,38 @@ int main(int argc, char* argv[])
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_m_k.savetxt("a.txt");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
d2_device_buf.ToDevice(d2_m_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
......@@ -358,26 +364,26 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
B0DataType,
D0DataType,
D1DataType,
D2DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
CDEElementOp>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(
sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{});
sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)
{
// const int t = sorted_token_ids(m);
for(int n = 0; n < N; ++n)
{
cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_m_n(t, n), d2_m_n(t, n));
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
}
}
......
......@@ -535,7 +535,7 @@ struct DeviceMoeGemm
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op) override
{
// assert(0, "no impl");
return std::make_unique<Argument>(nullptr, nullptr,
......
......@@ -901,9 +901,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
SrcCoord src_coord_;
DstCoord dst_coord_;
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
};
} // namespace ck
......@@ -687,10 +687,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using OOBVectorTuple = StaticallyIndexedArray<bool, src_num_access>;
StaticallyIndexedArray<OOBVectorTuple, NumThreadScratch> oob_vectors_tuple_;
StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
SrcCoords src_coords_;
DstCoords dst_coords_;
const ElementwiseOperation element_op_;
StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
};
} // namespace ck
......@@ -17,6 +17,9 @@ namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename D0DataType,
typename D1DataType,
typename D2DataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
......@@ -33,6 +36,9 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const index_t sorted_tile_size,
const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_e_n_k,
const Tensor<D0DataType>& d0,
const Tensor<D1DataType>& d1,
const Tensor<D2DataType>& d2,
Tensor<CDataType>& c_t_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -42,6 +48,9 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
sorted_tile_size_{sorted_tile_size},
a_m_k_{a_m_k},
b_e_n_k_{b_e_n_k},
d0_{d0},
d1_{d1},
d2_{d2},
c_t_n_{c_t_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
......@@ -49,16 +58,19 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
{
}
const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& sorted_token_ids_;
const Tensor<ck::index_t>& expert_ids_;
index_t sorted_tile_size_;
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_e_n_k_;
const Tensor<D0DataType>& d0_;
const Tensor<D1DataType>& d1_;
const Tensor<D2DataType>& d2_;
Tensor<CDataType>& c_t_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
index_t sorted_tile_size_;
};
// Invoker
......@@ -106,8 +118,10 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c{0};
arg.c_element_op_(v_c, v_acc);
D0DataType v_d0 = arg.d0_(m, n); // a
D0DataType v_d1 = arg.d1_(e, n); // b
D0DataType v_d2 = arg.d2_(e, 0); //expert
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_d2);
arg.c_t_n_(t, n) += v_c;
}
......@@ -140,12 +154,15 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const index_t sorted_tile_size,
const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_e_n_k,
const Tensor<D0DataType>& d0,
const Tensor<D1DataType>& d1,
const Tensor<D2DataType>& d2,
Tensor<CDataType>& c_t_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, c_t_n, a_element_op, b_element_op, c_element_op};
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment