"...composable_kernel.git" did not exist on "63af525c06363f398b851967da2740a2ace382b5"
Commit 418baed3 authored by coderfeli
Browse files

moe gemm1 scaleready

parent b02c0b82
...@@ -71,7 +71,7 @@ struct MulABScale ...@@ -71,7 +71,7 @@ struct MulABScale
(void)d2; // for gate, no d2 needed (void)d2; // for gate, no d2 needed
(void)d0; (void)d0;
(void)d1; (void)d1;
const float x0_f = c; const float x0_f = c * d1 * d0;
// const float x0_f = c; // const float x0_f = c;
e = ck::type_convert<EDataType>(x0_f); e = ck::type_convert<EDataType>(x0_f);
} }
...@@ -286,9 +286,9 @@ int main(int argc, char* argv[]) ...@@ -286,9 +286,9 @@ int main(int argc, char* argv[])
case 1: case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2}); a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2}); d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{1, 3});
break; break;
case 2: case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{}); a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
...@@ -304,6 +304,9 @@ int main(int argc, char* argv[]) ...@@ -304,6 +304,9 @@ int main(int argc, char* argv[])
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0}); d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
} }
d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
d2_m_n.savetxt("d2_m_n.txt", "int");
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
...@@ -325,8 +328,6 @@ int main(int argc, char* argv[]) ...@@ -325,8 +328,6 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{}; auto cde_element_op = CDEElementOp{};
constexpr auto I0 = ck::Number<0>{};
// do GEMM // do GEMM
auto device_op = DeviceOpInstance{}; auto device_op = DeviceOpInstance{};
...@@ -352,7 +353,7 @@ int main(int argc, char* argv[]) ...@@ -352,7 +353,7 @@ int main(int argc, char* argv[])
K, K,
StrideA, StrideA,
StrideB, StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0, I0}, StrideDs,
StrideE, StrideE,
KBatch, KBatch,
a_element_op, a_element_op,
...@@ -406,9 +407,10 @@ int main(int argc, char* argv[]) ...@@ -406,9 +407,10 @@ int main(int argc, char* argv[])
{ {
const int t = sorted_token_ids(m); const int t = sorted_token_ids(m);
const int e = expert_ids(m / sorted_tile_size);
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(m, n), d2_m_n(m, n)); cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n), d2_m_n(m, n));
} }
} }
......
...@@ -1401,7 +1401,7 @@ struct GridwiseMoeGemmGather ...@@ -1401,7 +1401,7 @@ struct GridwiseMoeGemmGather
if (i.value == 1) if (i.value == 1)
{ {
ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1); ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
// if ( threadIdx.x ==0) // if ( threadIdx.x % 16 ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]); // printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
} }
return make_dynamic_buffer<AddressSpaceEnum::Global>( return make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -1448,10 +1448,11 @@ struct GridwiseMoeGemmGather ...@@ -1448,10 +1448,11 @@ struct GridwiseMoeGemmGather
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme // too hack here, 2 specific for topk weights, fixme
const float *p_sorted_weights = p_ds_grid[I2]; const float *p_sorted_weights = p_ds_grid[I0];
static_for<0, EMRepeats, 1>{}([&](auto m0) { static_for<0, EMRepeats, 1>{}([&](auto m0) {
scatter_offsets(m0) = 0; scatter_offsets(m0) = 0;
scatter_weights(m0) = p_sorted_weights[c_token_pos + m0]; scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]];
// if(threadIdx.x % 16 == 0)
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0)); // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
}); });
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
......
...@@ -176,10 +176,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter ...@@ -176,10 +176,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
src_coords_[i]); src_coords_[i]);
oob_val = oob_val & is_src_valid; oob_val = oob_val & is_src_valid;
if (i.value == ScatterWeightIdx) if (i.value == ScatterWeightIdx)
{ {
static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec"); static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1, "scatter weight dim, should only one vec");
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{}); constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number<iScatter>{}));
static_for<0, SrcScalarPerVector, 1>{}( static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{}); }); [&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{}); });
} }
...@@ -189,11 +191,15 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter ...@@ -189,11 +191,15 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using DataType = remove_cvref_t<decltype(data_types[i])>; using DataType = remove_cvref_t<decltype(data_types[i])>;
const auto tmp = const auto tmp =
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true); src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
static_for<0, SrcScalarPerVector, 1>{}( static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; }); [&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
} }
else else
{ {
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d vn\n", blockIdx.y, threadIdx.x, i.value);
src_vectors(i).template AsType<src_vector_t>()(I0) = src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true); src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment