Commit 69f54ee8 authored by coderfeli

Implement the 3-D-tensor epilogue; works OK

parent 72752420
......@@ -35,51 +35,79 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using B0DataType = F8;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using D2DataType = EDataType;
// using DsDataTypeGate = ck::Tuple<D0DataType, D1DataType>;
using DsDataTypeUp = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
using D2Layout = ELayout;
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
using DsLayoutUp = ck::Tuple<D0Layout, D1Layout, D2Layout>;
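// Ds tuple for the "up" path: three extra epilogue tensors per output element,
// two F32 scales (D0 row-major, D1 column-major) and a D2 in the output type
// and layout (F16, row-major). The 2-tensor "Gate" tuple is kept commented out
// above for reference.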
struct MultiplyMultiply
// for gate: a_scale, b_scale
struct MulABScale
{
template <typename E, typename C, typename D0, typename D1>
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(EDataType& e,
__host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>
(EDataType& e,
const float& c,
const float& d0,
const float& d1) const
const float& d1,
const D2DataType& d2) const
{
// const float x0_f = c * d0 * d1;
const float x0_f = c;
// printf("epi %f\n", c);
(void)d0; (void)d1; (void)d2;
const float x0_f = c;
e = ck::type_convert<EDataType>(x0_f);
}
};
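// Note (reference only): with the commented multiply re-enabled, MulABScale
// would compute e = convert<F16>(c * d0 * d1), i.e. dequantize the fp8
// accumulator with the A/B scales. In this commit the scales are bypassed and
// the accumulator is converted directly, with d0/d1/d2 silenced via (void).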
// template <>
// __host__ __device__ constexpr void operator()<BF16, float, float, float>(BF16& e,
// const float& c,
// const float& d0,
// const float& d1) const
// {
// const float x0_f = c;
// // const float x0_f = c * d0 * d1;
// e = ck::type_convert<BF16>(x0_f);
// }
// for gate: a_scale, b_scale, fused SiLU
struct MulABScaleSiluMulGate
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>
(EDataType& e,
const float& c,
const float& d0,
const float& d1,
const D2DataType& d2) const
{
// act
(void)d2;
float x0;
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
// fuse mul
e = ck::type_convert<EDataType>(x0);
}
};
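// Reference math for the fused epilogue, assuming ck's Silu is the standard
// x * sigmoid(x):
//   float x    = c * d0 * d1;              // dequantize with a/b scales
//   float silu = x / (1.0f + expf(-x));    // SiLU activation
//   e = ck::type_convert<EDataType>(silu);
// d2 (the companion gate/up activation) is accepted but not yet multiplied in.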
// using DsLayout = DsLayoutGate;
// using DsDataType = DsDataTypeGate;
using DsLayout = DsLayoutUp;
using DsDataType = DsDataTypeUp;
using CDEElementOp = MulABScale;
// using CDEElementOp = MulABScaleSiluMulGate;
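// DsLayout/DsDataType and the CDE epilogue functor are switched together here;
// the commented "Gate" aliases and MulABScaleSiluMulGate keep the alternative
// configuration one edit away.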
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
......@@ -115,10 +143,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; // TODO: fix this constraint
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
......@@ -142,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// ak1, bk1
AK1, BK1,
// mn_perxdl
32, 32,
MNPerXDL, MNPerXDL,
// mn_xdlperwave
MXDLPerWave, 1,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
......@@ -153,7 +181,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
1, 1, S<1, 32, 1, 8>, S<EVec, EVec, 1>,
1, 1, S<1, 32, 1, 8>, S<EVec, EVec, 1, EVec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......@@ -203,26 +231,12 @@ int main(int argc, char* argv[])
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = 0;
// ck::index_t StrideD = 0;
ck::index_t StrideE = N;
ck::index_t KBatch = 1;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
// const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1}));
......@@ -244,17 +258,16 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
// Tensor<B0DataType> b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{}));
// Tensor<B0DataType> b0_preshuffled(
// f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
Tensor<D0DataType> d0_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
......@@ -265,33 +278,38 @@ int main(int argc, char* argv[])
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_t_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_t_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_t_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_t_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
d2_device_buf.ToDevice(d2_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
......@@ -318,7 +336,8 @@ int main(int argc, char* argv[])
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
d1_device_buf.GetDeviceBuffer(),
d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
SORTED_SIZE,
......@@ -326,7 +345,7 @@ int main(int argc, char* argv[])
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0},
std::array<ck::index_t, NumDTensor>{I0, I0, I0},
StrideE,
KBatch,
a_element_op,
......@@ -382,7 +401,7 @@ int main(int argc, char* argv[])
const int t = sorted_token_ids(m);
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_t_n(t, n));
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_m_n(m, n), d2_m_n(m, n));
}
}
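// Host reference indexing: d0 is stored as a 1-D vector (descriptor {N, 1}
// with stride 0 in the second dimension) and read with the token id t, while
// d1/d2 use the sorted row index m to match the SORTED_SIZE x N device buffers.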
......
......@@ -35,51 +35,51 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using B0DataType = F8;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using D2DataType = EDataType;
// using DsDataTypeGate = ck::Tuple<D0DataType, D1DataType>;
using DsDataTypeUp = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
using D2Layout = ELayout;
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
using DsLayoutUp = ck::Tuple<D0Layout, D1Layout, D2Layout>;
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(EDataType& e,
__host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>
(EDataType& e,
const float& c,
const float& d0,
const float& d1) const
const float& d1,
const D2DataType& d2) const
{
// const float x0_f = c * d0 * d1;
const float x0_f = c;
// printf("epi %f\n", c);
(void)d0; (void)d1; (void)d2;
const float x0_f = c;
e = ck::type_convert<EDataType>(x0_f);
}
// template <>
// __host__ __device__ constexpr void operator()<BF16, float, float, float>(BF16& e,
// const float& c,
// const float& d0,
// const float& d1) const
// {
// const float x0_f = c;
// // const float x0_f = c * d0 * d1;
// e = ck::type_convert<BF16>(x0_f);
// }
};
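// As in the gather example, this pass-through epilogue accepts the two scales
// and d2 (silenced with (void)) and currently emits e = convert<F16>(c); the
// scale multiply is left commented out.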
// using DsLayout = DsLayoutGate;
// using DsDataType = DsDataTypeGate;
using DsLayout = DsLayoutUp;
using DsDataType = DsDataTypeUp;
using CDEElementOp = MultiplyMultiply;
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
......@@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1>,
CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1, EVec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
......@@ -205,26 +205,12 @@ int main(int argc, char* argv[])
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = 0;
// ck::index_t StrideD = 0;
ck::index_t StrideE = N;
ck::index_t KBatch = 1;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
// const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1}));
......@@ -246,17 +232,16 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
// Tensor<B0DataType> b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{}));
// Tensor<B0DataType> b0_preshuffled(
// f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
Tensor<D0DataType> d0_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl;
std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
......@@ -267,33 +252,38 @@ int main(int argc, char* argv[])
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_t_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_t_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_t_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_m_k.savetxt("a.txt");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_t_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
d2_device_buf.ToDevice(d2_m_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
......@@ -320,7 +310,8 @@ int main(int argc, char* argv[])
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
d1_device_buf.GetDeviceBuffer(),
d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
SORTED_SIZE,
......@@ -328,7 +319,7 @@ int main(int argc, char* argv[])
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0},
std::array<ck::index_t, NumDTensor>{I0, I0, I0},
StrideE,
KBatch,
a_element_op,
......@@ -386,7 +377,7 @@ int main(int argc, char* argv[])
// const int t = sorted_token_ids(m);
for(int n = 0; n < N; ++n)
{
cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_t_n(t, n));
cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_m_n(t, n), d2_m_n(t, n));
}
}
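// Scatter-variant reference: the output tensor is tokens x N rather than
// SORTED_SIZE x N, and all operands are read with the token row index t
// (the sorted_token_ids lookup above is commented out).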
......
......@@ -301,7 +301,6 @@ struct DeviceMoeGemm
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
using meme
// if(arg.KBatch > 1)
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
......@@ -435,7 +434,7 @@ struct DeviceMoeGemm
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return -1;//Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -60,39 +60,39 @@ __global__ void
#endif // end of if (defined(__gfx9__))
}
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_moe_gemm_gather_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
}
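// kernel_moe_gemm_gather_2lds: double-LDS variant of the gather kernel. Two
// shared-memory buffers (p_shared / p_shared1) are handed to Run_2Lds,
// presumably to double-buffer tile loads, and the split-K offsets for A/B are
// derived from blockIdx.z via SplitKBatchOffset.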
// template <typename GridwiseGemm,
// bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// index_t MinimumOccupancy = 1,
// TailNumber TailNum = TailNumber::Even>
// __global__ void
// #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
// #endif
// // __attribute__((amdgpu_waves_per_eu(1, 1)))
// kernel_moe_gemm_gather_2lds(typename GridwiseGemm::Argument karg)
// {
// #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
// GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_ds_grid,
// karg.p_c_grid,
// p_shared,
// p_shared1,
// karg,
// karg.a_element_op,
// karg.b_element_op,
// karg.c_element_op);
// #else
// ignore = karg;
// #endif // end of if (defined(__gfx9__))
// }
template <typename ALayout,
typename BLayout,
......@@ -1143,8 +1143,8 @@ struct GridwiseMoeGemmGather
gather_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.K;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
// const index_t m_block_data_idx_on_grid =
// __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
// N0, K0, Blocksize*KPack
......@@ -1515,52 +1515,52 @@ struct GridwiseMoeGemmGather
}
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
void* p_shared1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
// const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
// Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// p_a_grid,
// p_b_grid,
// p_ds_grid,
// p_c_grid,
// p_shared,
// p_shared1,
// problem,
// a_element_op,
// b_element_op,
// c_element_op,
// block_2_ctile_map);
}
template <typename Block2CTileMap,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
void* p_shared1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
}
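// Note: both Run_2Lds overloads shown above are empty stubs; their bodies
// (including the Block2CTileMap dispatch) are commented out, so the 2-LDS
// path compiles but does no work yet.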
// template <bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// TailNumber TailNum = TailNumber::Odd>
// __device__ static void Run_2Lds(const ADataType* p_a_grid,
// const BDataType* p_b_grid,
// DsGridPointer& p_ds_grid,
// CDataType* p_c_grid,
// void* p_shared,
// void* p_shared1,
// const Problem& problem,
// AElementwiseOperation a_element_op,
// BElementwiseOperation b_element_op,
// CElementwiseOperation c_element_op)
// {
// // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
// // Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// // p_a_grid,
// // p_b_grid,
// // p_ds_grid,
// // p_c_grid,
// // p_shared,
// // p_shared1,
// // problem,
// // a_element_op,
// // b_element_op,
// // c_element_op,
// // block_2_ctile_map);
// }
// template <typename Block2CTileMap,
// bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// TailNumber TailNum = TailNumber::Odd>
// __device__ static void Run_2Lds(const ADataType* p_a_grid,
// const BDataType* p_b_grid,
// DsGridPointer& p_ds_grid,
// CDataType* p_c_grid,
// void* p_shared,
// void* p_shared1,
// const Problem& problem,
// AElementwiseOperation a_element_op,
// BElementwiseOperation b_element_op,
// CElementwiseOperation c_element_op,
// const Block2CTileMap& block_2_ctile_map)
// {
// }
};
} // namespace ck
......@@ -60,39 +60,39 @@ __global__ void
#endif // end of if (defined(__gfx9__))
}
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_moe_gemm_scatter_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
}
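// Scatter counterpart of the 2-LDS kernel: same double shared-memory
// allocation and split-K offset handling, dispatched through
// GridwiseGemm::Run_2Lds of the scatter gridwise GEMM.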
// template <typename GridwiseGemm,
// bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// index_t MinimumOccupancy = 1,
// TailNumber TailNum = TailNumber::Even>
// __global__ void
// #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
// #endif
// // __attribute__((amdgpu_waves_per_eu(1, 1)))
// kernel_moe_gemm_scatter_2lds(typename GridwiseGemm::Argument karg)
// {
// #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
// GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_ds_grid,
// karg.p_c_grid,
// p_shared,
// p_shared1,
// karg,
// karg.a_element_op,
// karg.b_element_op,
// karg.c_element_op);
// #else
// ignore = karg;
// #endif // end of if (defined(__gfx9__))
// }
template <typename ALayout,
typename BLayout,
......@@ -1506,52 +1506,52 @@ struct GridwiseMoeGemmScatter
}
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
void* p_shared1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
// const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
// Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// p_a_grid,
// p_b_grid,
// p_ds_grid,
// p_c_grid,
// p_shared,
// p_shared1,
// problem,
// a_element_op,
// b_element_op,
// c_element_op,
// block_2_ctile_map);
}
template <typename Block2CTileMap,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
void* p_shared1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
}
// template <bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// TailNumber TailNum = TailNumber::Odd>
// __device__ static void Run_2Lds(const ADataType* p_a_grid,
// const BDataType* p_b_grid,
// DsGridPointer& p_ds_grid,
// CDataType* p_c_grid,
// void* p_shared,
// void* p_shared1,
// const Problem& problem,
// AElementwiseOperation a_element_op,
// BElementwiseOperation b_element_op,
// CElementwiseOperation c_element_op)
// {
// // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
// // Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// // p_a_grid,
// // p_b_grid,
// // p_ds_grid,
// // p_c_grid,
// // p_shared,
// // p_shared1,
// // problem,
// // a_element_op,
// // b_element_op,
// // c_element_op,
// // block_2_ctile_map);
// }
// template <typename Block2CTileMap,
// bool HasMainKBlockLoop,
// InMemoryDataOperationEnum CGlobalMemoryDataOperation,
// TailNumber TailNum = TailNumber::Odd>
// __device__ static void Run_2Lds(const ADataType* p_a_grid,
// const BDataType* p_b_grid,
// DsGridPointer& p_ds_grid,
// CDataType* p_c_grid,
// void* p_shared,
// void* p_shared1,
// const Problem& problem,
// AElementwiseOperation a_element_op,
// BElementwiseOperation b_element_op,
// CElementwiseOperation c_element_op,
// const Block2CTileMap& block_2_ctile_map)
// {
// }
};
} // namespace ck