Commit e4ca61f9 authored by coderfeli's avatar coderfeli
Browse files

moe gemm2 scales ok

parent 66d08ea3
......@@ -132,7 +132,6 @@ static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
......@@ -185,7 +184,7 @@ int main(int argc, char* argv[])
ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 32;
ck::index_t tokens = 64;
if(argc == 1)
{
......@@ -234,14 +233,13 @@ int main(int argc, char* argv[])
else
sorted_token_ids.mData[i] = tokens;
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, N}, {1, 0}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({SORTED_SIZE, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
......@@ -285,6 +283,11 @@ int main(int argc, char* argv[])
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_m_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
d0_m_n.savetxt("d0_m_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data());
......
......@@ -1382,10 +1382,12 @@ struct GridwiseMoeGemmScatter
// ascale M, 1; bscale E, N, 1, move ptr to E
if (i.value == 1)
{
ptr_ += expert_id * problem.StrideDs[1] * problem.N;
ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
// if ( threadIdx.x ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
}
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
......
......@@ -174,7 +174,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
src_coords_[i]);
oob_val = oob_val & is_src_valid;
if (i.value == 2)
if (i.value == 3)
{
static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec");
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
......@@ -187,6 +187,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using DataType = remove_cvref_t<decltype(data_types[i])>;
const auto tmp =
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
// if(i.value == 2)
// printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
......
......@@ -89,6 +89,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const int t = arg.sorted_token_ids_(m);
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0];
D2DataType v_topk_w = arg.d2_(m, 0); //expert
if(t < token_cnt) {
for(int k = 0; k < K; ++k)
......@@ -120,9 +121,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType v_c{0};
D0DataType v_d0 = arg.d0_(m, n); // a
D0DataType v_d1 = arg.d1_(e, n); // b
D0DataType v_d2 = arg.d2_(e, 0); //expert
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_d2);
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
arg.c_t_n_(t, n) += v_c;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment