Commit 59f3e009 authored by coderfeli's avatar coderfeli
Browse files

remove d2 for gemm1

parent 418baed3
...@@ -40,73 +40,56 @@ using AccDataType = F32; ...@@ -40,73 +40,56 @@ using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using D0DataType = F32; using D0DataType = F32;
using D1DataType = F32; using D1DataType = F32;
using D2DataType = EDataType; using DsDataType = ck::Tuple<D0DataType, D1DataType>;
// using DsDataTypeGate = ck::Tuple<D0DataType, D1DataType>;
using DsDataTypeUp = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row; using A0Layout = Row;
using B0Layout = Col; using B0Layout = Col;
using ELayout = Row; using ELayout = Row;
using D0Layout = Row; using D0Layout = Row;
using D1Layout = Col; using D1Layout = Col;
using D2Layout = ELayout; using DsLayout = ck::Tuple<D0Layout, D1Layout>;
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
using DsLayoutUp = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// for gate, a_scale, b_scale // for gate, a_scale, b_scale
struct MulABScale struct MulABScale
{ {
template <typename E, typename C, typename D0, typename D1, typename D2> template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <> template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType> __host__ __device__ constexpr void operator()<EDataType, float, float, float>
(EDataType& e, (EDataType& e,
const float& c, const float& c,
const float& d0, const float& d0,
const float& d1, const float& d1) const
const D2DataType& d2) const
{ {
(void)d2; // for gate, no d2 needed e = ck::type_convert<EDataType>(c * d1 * d0);
(void)d0;
(void)d1;
const float x0_f = c * d1 * d0;
// const float x0_f = c;
e = ck::type_convert<EDataType>(x0_f);
} }
}; };
// for gate, a_scale, b_scale, fuse silu, // for gate, a_scale, b_scale, fuse silu,
struct MulABScaleSiluMulGate struct MulABScaleSilu
{ {
template <typename E, typename C, typename D0, typename D1, typename D2> template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <> template <>
    __host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>         __host__ __device__ constexpr void operator()<EDataType, float, float, float>
(EDataType& e, (EDataType& e,
const float& c, const float& c,
const float& d0, const float& d0,
const float& d1, const float& d1) const
const D2DataType& d2) const
{ {
// act // act
(void)d0;
(void)d1;
(void)d2;
float x0 = 0; float x0 = 0;
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
// fuse mul
e = ck::type_convert<EDataType>(x0); e = ck::type_convert<EDataType>(x0);
} }
}; };
// using DsLayout = DsLayoutGate; // using DsLayout = DsLayoutGate;
// using DsDataType = DsDataTypeGate; // using DsDataType = DsDataTypeGate;
using DsLayout = DsLayoutUp;
using DsDataType = DsDataTypeUp;
using CDEElementOp = MulABScale; using CDEElementOp = MulABScale;
...@@ -158,7 +141,6 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); ...@@ -158,7 +141,6 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off // clang-format off
...@@ -188,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm ...@@ -188,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
1, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>, 1, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
...@@ -241,7 +223,7 @@ int main(int argc, char* argv[]) ...@@ -241,7 +223,7 @@ int main(int argc, char* argv[])
// ck::index_t StrideD = 0; // ck::index_t StrideD = 0;
ck::index_t StrideE = N; ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size(); constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0}; constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
ck::index_t KBatch = 1; ck::index_t KBatch = 1;
...@@ -269,14 +251,12 @@ int main(int argc, char* argv[]) ...@@ -269,14 +251,12 @@ int main(int argc, char* argv[])
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1})); Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1})); Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
...@@ -288,32 +268,27 @@ int main(int argc, char* argv[]) ...@@ -288,32 +268,27 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3}); d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3}); d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{1, 3});
break; break;
case 2: case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{}); a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break; break;
default: default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
} }
d0_t_n.savetxt("d0_t_n.txt", "int"); d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int"); d1_e_n.savetxt("d1_e_n.txt", "int");
d2_m_n.savetxt("d2_m_n.txt", "int");
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt"); a0_t_k.savetxt("a.txt");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
...@@ -321,7 +296,6 @@ int main(int argc, char* argv[]) ...@@ -321,7 +296,6 @@ int main(int argc, char* argv[])
a0_device_buf.ToDevice(a0_t_k.mData.data()); a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data()); e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
...@@ -344,8 +318,7 @@ int main(int argc, char* argv[]) ...@@ -344,8 +318,7 @@ int main(int argc, char* argv[])
a0_device_buf.GetDeviceBuffer(), a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(), std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
tokens, tokens,
SORTED_SIZE, SORTED_SIZE,
...@@ -410,7 +383,7 @@ int main(int argc, char* argv[]) ...@@ -410,7 +383,7 @@ int main(int argc, char* argv[])
const int e = expert_ids(m / sorted_tile_size); const int e = expert_ids(m / sorted_tile_size);
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n), d2_m_n(m, n)); cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment