Commit 69f54ee8 authored by coderfeli

impl 3ds epilog ok

parent 72752420
@@ -35,51 +35,79 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using A0DataType = F8;
 using B0DataType = F8;
+using EDataType = F16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using D0DataType = F32;
 using D1DataType = F32;
-using DsDataType = ck::Tuple<D0DataType, D1DataType>;
-using EDataType = F16;
+using D2DataType = EDataType;
+// using DsDataTypeGate = ck::Tuple<D0DataType, D1DataType>;
+using DsDataTypeUp = ck::Tuple<D0DataType, D1DataType, D2DataType>;
 
 using A0Layout = Row;
 using B0Layout = Col;
+using ELayout = Row;
 using D0Layout = Row;
 using D1Layout = Col;
-using DsLayout = ck::Tuple<D0Layout, D1Layout>;
-using ELayout = Row;
+using D2Layout = ELayout;
+// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
+using DsLayoutUp = ck::Tuple<D0Layout, D1Layout, D2Layout>;
 
-struct MultiplyMultiply
+// for gate, a_scale, b_scale
+struct MulABScale
 {
-    template <typename E, typename C, typename D0, typename D1>
+    template <typename E, typename C, typename D0, typename D1, typename D2>
     __host__ __device__ constexpr void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+    operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
 
     template <>
-    __host__ __device__ constexpr void operator()<EDataType, float, float, float>(EDataType& e,
-                                                                                  const float& c,
-                                                                                  const float& d0,
-                                                                                  const float& d1) const
+    __host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>(
+        EDataType& e,
+        const float& c,
+        const float& d0,
+        const float& d1,
+        const D2DataType& d2) const
     {
         // const float x0_f = c * d0 * d1;
-        const float x0_f = c;
-        // printf("epi %f\n", c);
+        (void)d0; (void)d1; (void)d2;
+        const float x0_f = c;
         e = ck::type_convert<EDataType>(x0_f);
     }
+};
 
-    // template <>
-    // __host__ __device__ constexpr void operator()<BF16, float, float, float>(BF16& e,
-    //                                                                          const float& c,
-    //                                                                          const float& d0,
-    //                                                                          const float& d1) const
-    // {
-    //     const float x0_f = c;
-    //     // const float x0_f = c * d0 * d1;
-    //     e = ck::type_convert<BF16>(x0_f);
-    // }
+// for gate, a_scale, b_scale, fuse silu
+struct MulABScaleSiluMulGate
+{
+    template <typename E, typename C, typename D0, typename D1, typename D2>
+    __host__ __device__ constexpr void
+    operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>(
+        EDataType& e,
+        const float& c,
+        const float& d0,
+        const float& d1,
+        const D2DataType& d2) const
+    {
+        // act
+        (void)d2;
+        float x0;
+        ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
+        // fuse mul
+        e = ck::type_convert<EDataType>(x0);
+    }
 };
 
+// using DsLayout = DsLayoutGate;
+// using DsDataType = DsDataTypeGate;
+using DsLayout = DsLayoutUp;
+using DsDataType = DsDataTypeUp;
+using CDEElementOp = MulABScale;
+// using CDEElementOp = MulABScaleSiluMulGate;
+
 void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
 {
@@ -115,10 +143,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
-using CDEElementOp = MultiplyMultiply;
 
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr ck::index_t MPerBlock = 128;
+static constexpr ck::index_t MNPerXDL = 32;
 static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
 static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; // todo: fix this constraint
 static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
@@ -142,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
     // ak1, bk1
     AK1, BK1,
     // mn_perxdl
-    32, 32,
+    MNPerXDL, MNPerXDL,
    // mn_xdlperwave
     MXDLPerWave, 1,
     // a,b: load-transfer cluster, cluster order, src order, VECDIM, src per vec, dst per vec, lds_extra
@@ -153,7 +181,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
     // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
     // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
     // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-    1, 1, S<1, 32, 1, 8>, S<EVec, EVec, 1>,
+    1, 1, S<1, 32, 1, 8>, S<EVec, EVec, 1, EVec>,
     ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
 
 // kernel 2: 128->32x128x128
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -203,26 +231,12 @@ int main(int argc, char* argv[])
     ck::index_t StrideA = K;
     ck::index_t StrideB = K;
-    ck::index_t StrideD = 0;
+    // ck::index_t StrideD = 0;
     ck::index_t StrideE = N;
     ck::index_t KBatch = 1;
 
-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            using namespace ck::literals;
-
-            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return HostTensorDescriptor({row, col}, {stride, 1_uz});
-            }
-            else
-            {
-                return HostTensorDescriptor({row, col}, {1_uz, stride});
-            }
-        };
-
     // const ck::index_t experts = 8;
     Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
     Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1}));
@@ -244,17 +258,16 @@ int main(int argc, char* argv[])
     Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
     Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
     Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-    // Tensor<B0DataType> b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{}));
-    // Tensor<B0DataType> b0_preshuffled(
-    //     f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
-    Tensor<D0DataType> d0_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D0Layout{}));
-    Tensor<D1DataType> d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{}));
+    Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
+    Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
+    Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
     Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
     Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
 
     std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
     std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
-    std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl;
+    std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
+    std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
     std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
     std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
@@ -265,33 +278,38 @@ int main(int argc, char* argv[])
         a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
         d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
-        d1_t_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
+        d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
         break;
     case 2:
         a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
         d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
-        d1_t_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
+        d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
         break;
     default:
         a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
         d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
-        d1_t_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
+        d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
     }
 
     DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
     DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
     DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
     DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
     DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
-    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_t_n.mDesc.GetElementSpaceSize());
+    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
 
     a0_t_k.savetxt("a.txt");
     sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
     expert_ids_dev.ToDevice(expert_ids.mData.data());
     a0_device_buf.ToDevice(a0_t_k.mData.data());
     d0_device_buf.ToDevice(d0_t_n.mData.data());
-    d1_device_buf.ToDevice(d1_t_n.mData.data());
+    d1_device_buf.ToDevice(d1_m_n.mData.data());
+    d2_device_buf.ToDevice(d2_m_n.mData.data());
     e_device_buf.ToDevice(e_m_n_device_result.mData.data());
 
     auto a_element_op = AElementOp{};
@@ -318,7 +336,8 @@ int main(int argc, char* argv[])
         a0_device_buf.GetDeviceBuffer(),
         b0_device_buf.GetDeviceBuffer(),
         std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
-                                            d1_device_buf.GetDeviceBuffer()},
+                                            d1_device_buf.GetDeviceBuffer(),
+                                            d2_device_buf.GetDeviceBuffer()},
         e_device_buf.GetDeviceBuffer(),
         tokens,
         SORTED_SIZE,
@@ -326,7 +345,7 @@ int main(int argc, char* argv[])
         K,
         StrideA,
         StrideB,
-        std::array<ck::index_t, NumDTensor>{I0, I0},
+        std::array<ck::index_t, NumDTensor>{I0, I0, I0},
         StrideE,
         KBatch,
         a_element_op,
@@ -382,7 +401,7 @@ int main(int argc, char* argv[])
             const int t = sorted_token_ids(m);
             for(int n = 0; n < N; ++n)
             {
-                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_t_n(t, n));
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_m_n(m, n), d2_m_n(m, n));
             }
         }
......
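Note on the two new element-ops in the first example: MulABScale is the "up" epilogue (accumulator times a_scale and b_scale; the committed body still short-circuits to e = c, with the full product left in a comment), and MulABScaleSiluMulGate applies SiLU to the scaled accumulator (the gate multiply implied by the name is not wired up yet; d2 is discarded). A standalone host sketch of the apparently intended math, under those assumptions; none of this is the CK implementation:

// Host-side sketch of the two epilogues on plain floats.
#include <cmath>
#include <cstdio>

// x * sigmoid(x), the activation ck::tensor_operation::element_wise::Silu applies
static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// MulABScale: scale the accumulator by the two per-tensor scales d0/d1
// (the committed version currently returns c unscaled, for debugging).
static float mul_ab_scale(float c, float d0, float d1) { return c * d0 * d1; }

// MulABScaleSiluMulGate: SiLU over the scaled accumulator, then multiply by
// the gate value d2. The d2 multiply is an assumption taken from the name;
// the commit still ignores d2.
static float mul_ab_scale_silu_mul_gate(float c, float d0, float d1, float d2)
{
    return silu(c * d0 * d1) * d2;
}

int main()
{
    const float c = 1.5f, a_scale = 0.5f, b_scale = 2.0f, gate = 0.25f;
    std::printf("up   epilogue: %f\n", mul_ab_scale(c, a_scale, b_scale));
    std::printf("gate epilogue: %f\n", mul_ab_scale_silu_mul_gate(c, a_scale, b_scale, gate));
}

With d2 carrying the output of the companion GEMM, this would match the usual MoE up/gate fusion applied element by element.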
@@ -35,51 +35,51 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using A0DataType = F8;
 using B0DataType = F8;
+using EDataType = F16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using D0DataType = F32;
 using D1DataType = F32;
-using DsDataType = ck::Tuple<D0DataType, D1DataType>;
-using EDataType = F16;
+using D2DataType = EDataType;
+// using DsDataTypeGate = ck::Tuple<D0DataType, D1DataType>;
+using DsDataTypeUp = ck::Tuple<D0DataType, D1DataType, D2DataType>;
 
 using A0Layout = Row;
 using B0Layout = Col;
+using ELayout = Row;
 using D0Layout = Row;
 using D1Layout = Col;
-using DsLayout = ck::Tuple<D0Layout, D1Layout>;
-using ELayout = Row;
+using D2Layout = ELayout;
+// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
+using DsLayoutUp = ck::Tuple<D0Layout, D1Layout, D2Layout>;
 
 struct MultiplyMultiply
 {
-    template <typename E, typename C, typename D0, typename D1>
+    template <typename E, typename C, typename D0, typename D1, typename D2>
     __host__ __device__ constexpr void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+    operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
 
     template <>
-    __host__ __device__ constexpr void operator()<EDataType, float, float, float>(EDataType& e,
-                                                                                  const float& c,
-                                                                                  const float& d0,
-                                                                                  const float& d1) const
+    __host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>(
+        EDataType& e,
+        const float& c,
+        const float& d0,
+        const float& d1,
+        const D2DataType& d2) const
     {
         // const float x0_f = c * d0 * d1;
-        const float x0_f = c;
-        // printf("epi %f\n", c);
+        (void)d0; (void)d1; (void)d2;
+        const float x0_f = c;
         e = ck::type_convert<EDataType>(x0_f);
     }
-
-    // template <>
-    // __host__ __device__ constexpr void operator()<BF16, float, float, float>(BF16& e,
-    //                                                                          const float& c,
-    //                                                                          const float& d0,
-    //                                                                          const float& d1) const
-    // {
-    //     const float x0_f = c;
-    //     // const float x0_f = c * d0 * d1;
-    //     e = ck::type_convert<BF16>(x0_f);
-    // }
 };
 
+// using DsLayout = DsLayoutGate;
+// using DsDataType = DsDataTypeGate;
+using DsLayout = DsLayoutUp;
+using DsDataType = DsDataTypeUp;
+using CDEElementOp = MultiplyMultiply;
+
 void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
 {
@@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
     // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
     // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
     // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-    CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1>,
+    CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1, EVec>,
     ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>;
 
 // kernel 2: 128->32x128x128
 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -205,26 +205,12 @@ int main(int argc, char* argv[])
     ck::index_t StrideA = K;
     ck::index_t StrideB = K;
-    ck::index_t StrideD = 0;
+    // ck::index_t StrideD = 0;
     ck::index_t StrideE = N;
     ck::index_t KBatch = 1;
 
-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            using namespace ck::literals;
-
-            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return HostTensorDescriptor({row, col}, {stride, 1_uz});
-            }
-            else
-            {
-                return HostTensorDescriptor({row, col}, {1_uz, stride});
-            }
-        };
-
     // const ck::index_t experts = 8;
     Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
     Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1}));
@@ -246,17 +232,16 @@ int main(int argc, char* argv[])
     Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
     Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
     Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-    // Tensor<B0DataType> b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{}));
-    // Tensor<B0DataType> b0_preshuffled(
-    //     f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
-    Tensor<D0DataType> d0_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D0Layout{}));
-    Tensor<D1DataType> d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{}));
+    Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
+    Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
+    Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
     Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
     Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
     e_t_n_device_result.SetZero();
 
     std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
     std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
-    std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl;
+    std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
+    std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
     std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
     std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
@@ -267,33 +252,38 @@ int main(int argc, char* argv[])
         a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
         d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
-        d1_t_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
+        d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
         break;
     case 2:
         a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
         d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
-        d1_t_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
+        d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
         break;
     default:
         a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
         d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
-        d1_t_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
+        d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
     }
 
     DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
     DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
     DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
     DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
     DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
-    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_t_n.mDesc.GetElementSpaceSize());
+    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
 
     a0_m_k.savetxt("a.txt");
     sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
     expert_ids_dev.ToDevice(expert_ids.mData.data());
     a0_device_buf.ToDevice(a0_m_k.mData.data());
     d0_device_buf.ToDevice(d0_t_n.mData.data());
-    d1_device_buf.ToDevice(d1_t_n.mData.data());
+    d1_device_buf.ToDevice(d1_m_n.mData.data());
+    d2_device_buf.ToDevice(d2_m_n.mData.data());
     e_device_buf.ToDevice(e_t_n_device_result.mData.data());
 
     auto a_element_op = AElementOp{};
@@ -320,7 +310,8 @@ int main(int argc, char* argv[])
         a0_device_buf.GetDeviceBuffer(),
         b0_device_buf.GetDeviceBuffer(),
         std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
-                                            d1_device_buf.GetDeviceBuffer()},
+                                            d1_device_buf.GetDeviceBuffer(),
+                                            d2_device_buf.GetDeviceBuffer()},
         e_device_buf.GetDeviceBuffer(),
         tokens,
         SORTED_SIZE,
@@ -328,7 +319,7 @@ int main(int argc, char* argv[])
         K,
         StrideA,
         StrideB,
-        std::array<ck::index_t, NumDTensor>{I0, I0},
+        std::array<ck::index_t, NumDTensor>{I0, I0, I0},
         StrideE,
         KBatch,
         a_element_op,
@@ -386,7 +377,7 @@ int main(int argc, char* argv[])
             // const int t = sorted_token_ids(m);
             for(int n = 0; n < N; ++n)
             {
-                cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_t_n(t, n));
+                cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_m_n(t, n), d2_m_n(t, n));
            }
         }
......
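Between the two examples, only the addressing differs: the gather example writes e(m, n) at the sorted position, while the scatter example writes e(t, n) at the original token row t = sorted_token_ids(m). A standalone sketch of the reference epilogue loop under the assumed full epilogue (the committed element-op still short-circuits to c); the vector names are hypothetical stand-ins for the Tensor<> objects:

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference epilogue over the sorted token list. Row m is the sorted position,
// t = sorted_token_ids[m] the original token row. d0 is a per-column scale
// (shape {N, 1} in the examples), so only n indexes it.
void reference_epilogue(const std::vector<int32_t>& sorted_token_ids,
                        const std::vector<float>& c,  // SORTED_SIZE x N accumulator
                        const std::vector<float>& d0, // N per-column scales
                        const std::vector<float>& d1, // SORTED_SIZE x N
                        const std::vector<float>& d2, // SORTED_SIZE x N
                        std::vector<float>& e,        // output
                        int N,
                        bool scatter)
{
    for(std::size_t m = 0; m < sorted_token_ids.size(); ++m)
    {
        // gather form writes the sorted row m; scatter form writes the token row t
        const int out_row = scatter ? sorted_token_ids[m] : static_cast<int>(m);
        for(int n = 0; n < N; ++n)
        {
            // assumed full epilogue: scaled accumulator times the gate input
            e[out_row * N + n] = c[m * N + n] * d0[n] * d1[m * N + n] * d2[m * N + n];
        }
    }
}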
@@ -301,7 +301,6 @@ struct DeviceMoeGemm
             // Tail number always full
             if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
             {
-                using meme
                 // if(arg.KBatch > 1)
                 // {
                 //     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
@@ -435,7 +434,7 @@ struct DeviceMoeGemm
     float Run(const BaseArgument* p_arg,
               const StreamConfig& stream_config = StreamConfig{}) override
     {
-        return -1; // Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
     }
 };
......
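The Run fix in the hunk above restores the standard type-erased dispatch: the override takes a BaseArgument*, downcasts to the concrete Argument, and forwards to the typed Run. Returning -1 had effectively disabled the device op whenever it was invoked through the base interface. A minimal sketch of the pattern, with simplified stand-in types (not the CK classes):

#include <stdexcept>

struct BaseArgument { virtual ~BaseArgument() = default; };

struct Invoker
{
    struct Argument : BaseArgument { int M = 0, N = 0, K = 0; };

    // Typed overload: does the real work.
    float Run(const Argument& arg) { return static_cast<float>(arg.M) * arg.N * arg.K; }

    // Type-erased overload: downcast, then forward to the typed overload.
    float Run(const BaseArgument* p_arg)
    {
        const auto* p = dynamic_cast<const Argument*>(p_arg);
        if(p == nullptr) { throw std::runtime_error("wrong argument type"); }
        return Run(*p);
    }
};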
@@ -60,39 +60,39 @@ __global__ void
 #endif // end of if (defined(__gfx9__))
 }
 
-template <typename GridwiseGemm,
-          bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          index_t MinimumOccupancy = 1,
-          TailNumber TailNum       = TailNumber::Even>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
-#endif
-    // __attribute__((amdgpu_waves_per_eu(1, 1)))
-    kernel_moe_gemm_gather_2lds(typename GridwiseGemm::Argument karg)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
-    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
-        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
-        karg.p_ds_grid,
-        karg.p_c_grid,
-        p_shared,
-        p_shared1,
-        karg,
-        karg.a_element_op,
-        karg.b_element_op,
-        karg.c_element_op);
-#else
-    ignore = karg;
-#endif // end of if (defined(__gfx9__))
-}
+// template <typename GridwiseGemm,
+//           bool HasMainKBlockLoop,
+//           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+//           index_t MinimumOccupancy = 1,
+//           TailNumber TailNum       = TailNumber::Even>
+// __global__ void
+// #if CK_USE_LAUNCH_BOUNDS
+//     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+// #endif
+//     // __attribute__((amdgpu_waves_per_eu(1, 1)))
+//     kernel_moe_gemm_gather_2lds(typename GridwiseGemm::Argument karg)
+// {
+// #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
+//     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+//     __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+//     auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
+//     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+//         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
+//         karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
+//         karg.p_ds_grid,
+//         karg.p_c_grid,
+//         p_shared,
+//         p_shared1,
+//         karg,
+//         karg.a_element_op,
+//         karg.b_element_op,
+//         karg.c_element_op);
+// #else
+//     ignore = karg;
+// #endif // end of if (defined(__gfx9__))
+// }
 
 template <typename ALayout,
           typename BLayout,
@@ -1143,8 +1143,8 @@ struct GridwiseMoeGemmGather
             gather_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.K;
             // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
         });
-        const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
+        // const index_t m_block_data_idx_on_grid =
+        //     __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
         const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
 
         // N0, K0, Blocksize*KPack
@@ -1515,52 +1515,52 @@ struct GridwiseMoeGemmGather
         }
     }
 
-    template <bool HasMainKBlockLoop,
-              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-              TailNumber TailNum = TailNumber::Odd>
-    __device__ static void Run_2Lds(const ADataType* p_a_grid,
-                                    const BDataType* p_b_grid,
-                                    DsGridPointer& p_ds_grid,
-                                    CDataType* p_c_grid,
-                                    void* p_shared,
-                                    void* p_shared1,
-                                    const Problem& problem,
-                                    AElementwiseOperation a_element_op,
-                                    BElementwiseOperation b_element_op,
-                                    CElementwiseOperation c_element_op)
-    {
-        // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
-        // Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        //     p_a_grid,
-        //     p_b_grid,
-        //     p_ds_grid,
-        //     p_c_grid,
-        //     p_shared,
-        //     p_shared1,
-        //     problem,
-        //     a_element_op,
-        //     b_element_op,
-        //     c_element_op,
-        //     block_2_ctile_map);
-    }
-
-    template <typename Block2CTileMap,
-              bool HasMainKBlockLoop,
-              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-              TailNumber TailNum = TailNumber::Odd>
-    __device__ static void Run_2Lds(const ADataType* p_a_grid,
-                                    const BDataType* p_b_grid,
-                                    DsGridPointer& p_ds_grid,
-                                    CDataType* p_c_grid,
-                                    void* p_shared,
-                                    void* p_shared1,
-                                    const Problem& problem,
-                                    AElementwiseOperation a_element_op,
-                                    BElementwiseOperation b_element_op,
-                                    CElementwiseOperation c_element_op,
-                                    const Block2CTileMap& block_2_ctile_map)
-    {
-    }
+    // template <bool HasMainKBlockLoop,
+    //           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+    //           TailNumber TailNum = TailNumber::Odd>
+    // __device__ static void Run_2Lds(const ADataType* p_a_grid,
+    //                                 const BDataType* p_b_grid,
+    //                                 DsGridPointer& p_ds_grid,
+    //                                 CDataType* p_c_grid,
+    //                                 void* p_shared,
+    //                                 void* p_shared1,
+    //                                 const Problem& problem,
+    //                                 AElementwiseOperation a_element_op,
+    //                                 BElementwiseOperation b_element_op,
+    //                                 CElementwiseOperation c_element_op)
+    // {
+    //     // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
+    //     // Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+    //     //     p_a_grid,
+    //     //     p_b_grid,
+    //     //     p_ds_grid,
+    //     //     p_c_grid,
+    //     //     p_shared,
+    //     //     p_shared1,
+    //     //     problem,
+    //     //     a_element_op,
+    //     //     b_element_op,
+    //     //     c_element_op,
+    //     //     block_2_ctile_map);
+    // }
+
+    // template <typename Block2CTileMap,
+    //           bool HasMainKBlockLoop,
+    //           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+    //           TailNumber TailNum = TailNumber::Odd>
+    // __device__ static void Run_2Lds(const ADataType* p_a_grid,
+    //                                 const BDataType* p_b_grid,
+    //                                 DsGridPointer& p_ds_grid,
+    //                                 CDataType* p_c_grid,
+    //                                 void* p_shared,
+    //                                 void* p_shared1,
+    //                                 const Problem& problem,
+    //                                 AElementwiseOperation a_element_op,
+    //                                 BElementwiseOperation b_element_op,
+    //                                 CElementwiseOperation c_element_op,
+    //                                 const Block2CTileMap& block_2_ctile_map)
+    // {
+    // }
 };
 
 } // namespace ck
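One detail worth noting in the gather hunk above: each entry of p_sorted_token_ids is packed, and only its low 24 bits (the 0xffffff mask) select the source row in A; the gather offset is that index times K. The diff never interprets the high bits, so reading them as a top-k slot id is an assumption. A minimal decoding sketch with a made-up packed value:

#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t packed   = (2u << 24) | 173u;      // hypothetical: slot 2, token 173
    const int      K        = 4096;
    const uint32_t token_id = packed & 0xffffffu;     // low 24 bits -> token row
    const uint32_t slot     = packed >> 24;           // assumed top-k slot id
    const uint64_t a_offset = uint64_t(token_id) * K; // element offset into A
    std::printf("token %u slot %u offset %llu\n", token_id, slot,
                static_cast<unsigned long long>(a_offset));
}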
@@ -60,39 +60,39 @@ __global__ void
 #endif // end of if (defined(__gfx9__))
 }
 
-template <typename GridwiseGemm,
-          bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          index_t MinimumOccupancy = 1,
-          TailNumber TailNum       = TailNumber::Even>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
-#endif
-    // __attribute__((amdgpu_waves_per_eu(1, 1)))
-    kernel_moe_gemm_scatter_2lds(typename GridwiseGemm::Argument karg)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
-    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
-        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
-        karg.p_ds_grid,
-        karg.p_c_grid,
-        p_shared,
-        p_shared1,
-        karg,
-        karg.a_element_op,
-        karg.b_element_op,
-        karg.c_element_op);
-#else
-    ignore = karg;
-#endif // end of if (defined(__gfx9__))
-}
+// template <typename GridwiseGemm,
+//           bool HasMainKBlockLoop,
+//           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+//           index_t MinimumOccupancy = 1,
+//           TailNumber TailNum       = TailNumber::Even>
+// __global__ void
+// #if CK_USE_LAUNCH_BOUNDS
+//     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+// #endif
+//     // __attribute__((amdgpu_waves_per_eu(1, 1)))
+//     kernel_moe_gemm_scatter_2lds(typename GridwiseGemm::Argument karg)
+// {
+// #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
+//     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+//     __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+//     auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
+//     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+//         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
+//         karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
+//         karg.p_ds_grid,
+//         karg.p_c_grid,
+//         p_shared,
+//         p_shared1,
+//         karg,
+//         karg.a_element_op,
+//         karg.b_element_op,
+//         karg.c_element_op);
+// #else
+//     ignore = karg;
+// #endif // end of if (defined(__gfx9__))
+// }
 
 template <typename ALayout,
           typename BLayout,
@@ -1506,52 +1506,52 @@ struct GridwiseMoeGemmScatter
         }
     }
 
-    template <bool HasMainKBlockLoop,
-              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-              TailNumber TailNum = TailNumber::Odd>
-    __device__ static void Run_2Lds(const ADataType* p_a_grid,
-                                    const BDataType* p_b_grid,
-                                    DsGridPointer& p_ds_grid,
-                                    CDataType* p_c_grid,
-                                    void* p_shared,
-                                    void* p_shared1,
-                                    const Problem& problem,
-                                    AElementwiseOperation a_element_op,
-                                    BElementwiseOperation b_element_op,
-                                    CElementwiseOperation c_element_op)
-    {
-        // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
-        // Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        //     p_a_grid,
-        //     p_b_grid,
-        //     p_ds_grid,
-        //     p_c_grid,
-        //     p_shared,
-        //     p_shared1,
-        //     problem,
-        //     a_element_op,
-        //     b_element_op,
-        //     c_element_op,
-        //     block_2_ctile_map);
-    }
-
-    template <typename Block2CTileMap,
-              bool HasMainKBlockLoop,
-              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-              TailNumber TailNum = TailNumber::Odd>
-    __device__ static void Run_2Lds(const ADataType* p_a_grid,
-                                    const BDataType* p_b_grid,
-                                    DsGridPointer& p_ds_grid,
-                                    CDataType* p_c_grid,
-                                    void* p_shared,
-                                    void* p_shared1,
-                                    const Problem& problem,
-                                    AElementwiseOperation a_element_op,
-                                    BElementwiseOperation b_element_op,
-                                    CElementwiseOperation c_element_op,
-                                    const Block2CTileMap& block_2_ctile_map)
-    {
-    }
+    // template <bool HasMainKBlockLoop,
+    //           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+    //           TailNumber TailNum = TailNumber::Odd>
+    // __device__ static void Run_2Lds(const ADataType* p_a_grid,
+    //                                 const BDataType* p_b_grid,
+    //                                 DsGridPointer& p_ds_grid,
+    //                                 CDataType* p_c_grid,
+    //                                 void* p_shared,
+    //                                 void* p_shared1,
+    //                                 const Problem& problem,
+    //                                 AElementwiseOperation a_element_op,
+    //                                 BElementwiseOperation b_element_op,
+    //                                 CElementwiseOperation c_element_op)
+    // {
+    //     // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
+    //     // Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+    //     //     p_a_grid,
+    //     //     p_b_grid,
+    //     //     p_ds_grid,
+    //     //     p_c_grid,
+    //     //     p_shared,
+    //     //     p_shared1,
+    //     //     problem,
+    //     //     a_element_op,
+    //     //     b_element_op,
+    //     //     c_element_op,
+    //     //     block_2_ctile_map);
+    // }
+
+    // template <typename Block2CTileMap,
+    //           bool HasMainKBlockLoop,
+    //           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+    //           TailNumber TailNum = TailNumber::Odd>
+    // __device__ static void Run_2Lds(const ADataType* p_a_grid,
+    //                                 const BDataType* p_b_grid,
+    //                                 DsGridPointer& p_ds_grid,
+    //                                 CDataType* p_c_grid,
+    //                                 void* p_shared,
+    //                                 void* p_shared1,
+    //                                 const Problem& problem,
+    //                                 AElementwiseOperation a_element_op,
+    //                                 BElementwiseOperation b_element_op,
+    //                                 CElementwiseOperation c_element_op,
+    //                                 const Block2CTileMap& block_2_ctile_map)
+    // {
+    // }
 };
 
 } // namespace ck
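The *_2lds kernels and Run_2Lds overloads disabled above allocate two __shared__ buffers of the same size so a block can stage the next K tile into one buffer while the MFMAs consume the other. A host-side sketch of that ping-pong schedule, with illustrative names and a trivial "compute" (a sum), not the CK implementation:

#include <array>
#include <numeric>
#include <vector>

int main()
{
    constexpr int KTile    = 8; // number of K tiles to stream through
    constexpr int TileSize = 4; // elements per tile
    std::vector<float> a(KTile * TileSize);
    std::iota(a.begin(), a.end(), 0.0f);

    // Two staging buffers standing in for the two LDS allocations.
    std::array<std::vector<float>, 2> lds = {std::vector<float>(TileSize),
                                             std::vector<float>(TileSize)};
    auto load = [&](int k, std::vector<float>& buf) {
        for(int i = 0; i < TileSize; ++i) buf[i] = a[k * TileSize + i];
    };

    float acc = 0.0f;
    load(0, lds[0]); // prologue: fill the first buffer
    for(int k = 0; k < KTile; ++k)
    {
        if(k + 1 < KTile) load(k + 1, lds[(k + 1) % 2]); // prefetch the next tile
        for(float v : lds[k % 2]) acc += v;              // consume the current tile
    }
    return acc > 0.0f ? 0 : 1;
}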