Commit 6b51413b authored by coderfeli's avatar coderfeli
Browse files

compile ok

parent b755f375
...@@ -516,10 +516,6 @@ include_directories(BEFORE ...@@ -516,10 +516,6 @@ include_directories(BEFORE
) )
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
......
...@@ -66,7 +66,6 @@ else() ...@@ -66,7 +66,6 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
......
...@@ -27,14 +27,14 @@ using S = ck::Sequence<Is...>; ...@@ -27,14 +27,14 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using FP8 = ck::f8_t; // using F16 = ck::f8_t;
using F32 = float; using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = FP8; using A0DataType = F16;
using B0DataType = FP8; using B0DataType = F16;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using D0DataType = F32; using D0DataType = F32;
...@@ -98,9 +98,9 @@ struct MultiplyMultiply ...@@ -98,9 +98,9 @@ struct MultiplyMultiply
} }
}; };
void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) void preShuffleBuffer(const F16* src, F16* dst, int N, int K, int NXdl)
{ {
int KPack = 16; int KPack = 8;
int NLane = NXdl; int NLane = NXdl;
int KLane = 64 / NLane; int KLane = 64 / NLane;
...@@ -145,20 +145,23 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -145,20 +145,23 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
///###### RCR ///###### RCR
// kernel 1: 256->32x128x128 // kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F16>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
32, 128, 256, 32, 128, 128,
16, 16, 8, 8,
32, 32, 32, 32,
1, 1, 1, 1,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F16>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
// clang-format on // clang-format on
...@@ -230,9 +233,19 @@ int main(int argc, char* argv[]) ...@@ -230,9 +233,19 @@ int main(int argc, char* argv[])
return HostTensorDescriptor({row, col}, {1_uz, stride}); return HostTensorDescriptor({row, col}, {1_uz, stride});
} }
}; };
const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({M}, {1}));
for (int i = 0; i < experts; i++) {
expert_ids.mData[i] = i;
}
for (int i = 0; i < M; i++) {
sorted_token_ids.mData[i] = i % (M / 2);
}
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{}));
Tensor<B0DataType> b0_preshuffled( Tensor<B0DataType> b0_preshuffled(
f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
...@@ -267,12 +280,16 @@ int main(int argc, char* argv[]) ...@@ -267,12 +280,16 @@ int main(int argc, char* argv[])
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
} }
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data()); a0_device_buf.ToDevice(a0_m_k.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
...@@ -297,7 +314,9 @@ int main(int argc, char* argv[]) ...@@ -297,7 +314,9 @@ int main(int argc, char* argv[])
auto invoker = device_op.MakeInvoker(); auto invoker = device_op.MakeInvoker();
auto argument = auto argument =
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(), std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()}, d1_device_buf.GetDeviceBuffer()},
...@@ -321,7 +340,7 @@ int main(int argc, char* argv[]) ...@@ -321,7 +340,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, true, 50}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -83,7 +83,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1 ...@@ -83,7 +83,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId())); make_multi_index(ThreadGroup::GetThreadId() % 8));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
...@@ -100,7 +100,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1 ...@@ -100,7 +100,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId())); make_multi_index(ThreadGroup::GetThreadId() % 8));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
......
...@@ -157,7 +157,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -157,7 +157,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
} }
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
float ave_time = 0; float ave_time = 0;
...@@ -249,30 +249,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -249,30 +249,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
// Tail number always full // Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
if(arg.KBatch > 1) // if(arg.KBatch > 1)
{ // {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ // {
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< // const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm, // GridwiseGemm,
true, // true,
InMemoryDataOperationEnum::AtomicAdd, // InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, // minimum_occupancy,
TailNumber::Odd>; // TailNumber::Odd>;
Run(kernel); // Run(kernel);
} // }
else // else
{ // {
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< // const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm, // GridwiseGemm,
true, // true,
InMemoryDataOperationEnum::AtomicAdd, // InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, // minimum_occupancy,
TailNumber::Even>; // TailNumber::Even>;
Run(kernel); // Run(kernel);
} // }
} // }
else // else
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
...@@ -296,175 +296,64 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -296,175 +296,64 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
} }
} }
} }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) // else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{ // {
if(arg.KBatch > 1) // if(arg.KBatch > 1)
{ // {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ // {
const auto kernel = // const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< // kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm, // GridwiseGemm,
true, // true,
InMemoryDataOperationEnum::AtomicAdd, // InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, // minimum_occupancy,
TailNumber::Odd>; // TailNumber::Odd>;
Run(kernel); // Run(kernel);
} // }
else // else
{ // {
const auto kernel = // const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< // kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm, // GridwiseGemm,
true, // true,
InMemoryDataOperationEnum::AtomicAdd, // InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, // minimum_occupancy,
TailNumber::Even>; // TailNumber::Even>;
Run(kernel); // Run(kernel);
} // }
} // }
else // else
{ // {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ // {
const auto kernel = // const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< // kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm, // GridwiseGemm,
true, // true,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum::Set,
minimum_occupancy, // minimum_occupancy,
TailNumber::Odd>; // TailNumber::Odd>;
Run(kernel); // Run(kernel);
} // }
else // else
{ // {
const auto kernel = // const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< // kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm, // GridwiseGemm,
true, // true,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum::Set,
minimum_occupancy, // minimum_occupancy,
TailNumber::Even>; // TailNumber::Even>;
Run(kernel); // Run(kernel);
} // }
} // }
} // }
else else
{ {
throw std::runtime_error("todo: only v1 & v2 support now"); throw std::runtime_error("todo: only v1 & v2 support now");
} }
} }
#if 0
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
throw std::runtime_error("todo: only v3 support now");
}
}
#endif
return ave_time; return ave_time;
} }
...@@ -517,7 +406,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -517,7 +406,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const void* p_a, static auto MakeArgument(const void* p_sorted_token_ids,
const void* p_sorted_expert_ids,
const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
...@@ -533,7 +424,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -533,7 +424,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{static_cast<const ADataType*>(p_a), return Argument{static_cast<const index_t*>(p_sorted_token_ids),
static_cast<const index_t*>(p_sorted_expert_ids),
static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
p_ds, p_ds,
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
...@@ -553,7 +446,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -553,7 +446,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
...@@ -569,7 +463,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -569,7 +463,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(nullptr, nullptr,
static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
p_ds, p_ds,
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment