Commit e4ca61f9 authored by coderfeli
Browse files

moe gemm2 scales ok

parent 66d08ea3
...@@ -132,7 +132,6 @@ static constexpr ck::index_t EVec = 16 / sizeof(EDataType); ...@@ -132,7 +132,6 @@ static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1; static constexpr ck::index_t D2Vec = 1;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off // clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...@@ -185,7 +184,7 @@ int main(int argc, char* argv[]) ...@@ -185,7 +184,7 @@ int main(int argc, char* argv[])
ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_num = 8;
ck::index_t sorted_tile_size = MPerBlock; ck::index_t sorted_tile_size = MPerBlock;
ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size;
ck::index_t tokens = 32; ck::index_t tokens = 64;
if(argc == 1) if(argc == 1)
{ {
...@@ -234,14 +233,13 @@ int main(int argc, char* argv[]) ...@@ -234,14 +233,13 @@ int main(int argc, char* argv[])
else else
sorted_token_ids.mData[i] = tokens; sorted_token_ids.mData[i] = tokens;
} }
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1})); Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, N}, {1, 0})); Tensor<D2DataType> d2_e_n(HostTensorDescriptor({SORTED_SIZE, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero(); e_t_n_device_result.SetZero();
...@@ -285,6 +283,11 @@ int main(int argc, char* argv[]) ...@@ -285,6 +283,11 @@ int main(int argc, char* argv[])
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_m_k.savetxt("a.txt"); a0_m_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
d0_m_n.savetxt("d0_m_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data()); a0_device_buf.ToDevice(a0_m_k.mData.data());
......
...@@ -1382,10 +1382,12 @@ struct GridwiseMoeGemmScatter ...@@ -1382,10 +1382,12 @@ struct GridwiseMoeGemmScatter
// ascale M, 1; bscale E, N, 1, move ptr to E // ascale M, 1; bscale E, N, 1, move ptr to E
if (i.value == 1) if (i.value == 1)
{ {
ptr_ += expert_id * problem.StrideDs[1] * problem.N; ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
// if ( threadIdx.x ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
} }
return make_dynamic_buffer<AddressSpaceEnum::Global>( return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
......
...@@ -174,7 +174,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter ...@@ -174,7 +174,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
src_coords_[i]); src_coords_[i]);
oob_val = oob_val & is_src_valid; oob_val = oob_val & is_src_valid;
if (i.value == 2) if (i.value == 3)
{ {
static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec"); static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec");
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{}); constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
...@@ -187,6 +187,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter ...@@ -187,6 +187,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using DataType = remove_cvref_t<decltype(data_types[i])>; using DataType = remove_cvref_t<decltype(data_types[i])>;
const auto tmp = const auto tmp =
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true); src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
// if(i.value == 2)
// printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp); // printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
static_for<0, SrcScalarPerVector, 1>{}( static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; }); [&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
......
...@@ -89,6 +89,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -89,6 +89,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const int t = arg.sorted_token_ids_(m); const int t = arg.sorted_token_ids_(m);
const int e = arg.expert_ids_(m / arg.sorted_tile_size_); const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0]; const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0];
D2DataType v_topk_w = arg.d2_(m, 0); //expert
if(t < token_cnt) { if(t < token_cnt) {
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
...@@ -120,9 +121,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ...@@ -120,9 +121,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType v_c{0}; CDataType v_c{0};
D0DataType v_d0 = arg.d0_(m, n); // a D0DataType v_d0 = arg.d0_(m, n); // a
D0DataType v_d1 = arg.d1_(e, n); // b D0DataType v_d1 = arg.d1_(e, n); // b
D0DataType v_d2 = arg.d2_(e, 0); //expert arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_d2);
arg.c_t_n_(t, n) += v_c; arg.c_t_n_(t, n) += v_c;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment