Commit 6b51413b authored by coderfeli's avatar coderfeli

compile ok

parent b755f375
@@ -516,10 +516,6 @@ include_directories(BEFORE
 )
 SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
-if(BUILD_DEV)
-add_compile_options(-Werror)
-add_compile_options(-Weverything)
-endif()
 message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
 if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
...
@@ -66,7 +66,6 @@ else()
 -Wunreachable-code
 -Wunused
 -Wno-reserved-identifier
--Werror
 -Wno-option-ignored
 -Wsign-compare
 -Wno-extra-semi-stmt
...
@@ -27,14 +27,14 @@ using S = ck::Sequence<Is...>;
 using F16 = ck::half_t;
 using BF16 = ck::bhalf_t;
-using FP8 = ck::f8_t;
+// using F16 = ck::f8_t;
 using F32 = float;
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
-using A0DataType = FP8;
-using B0DataType = FP8;
+using A0DataType = F16;
+using B0DataType = F16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using D0DataType = F32;
@@ -98,9 +98,9 @@ struct MultiplyMultiply
 }
 };
-void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
+void preShuffleBuffer(const F16* src, F16* dst, int N, int K, int NXdl)
 {
-int KPack = 16;
+int KPack = 8;
 int NLane = NXdl;
 int KLane = 64 / NLane;
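Note: the body of preShuffleBuffer lies outside this hunk; only its tiling constants change. KPack halves from 16 to 8 because half_t is twice the width of f8_t, which keeps each lane's contiguous read at 16 bytes. A hedged host-side sketch of one layout consistent with KPack/NLane/KLane as set above; the loop order is an assumption, not the commit's code:

    // Assumed layout: regroup column-major B[K x N] into
    // [N/NLane][K0][KLane][NLane][KPack], K0 = K / (KLane * KPack),
    // so every lane of a 64-wide wave reads KPack contiguous values.
    template <typename T>
    void preShuffleRef(const T* src, T* dst, int N, int K, int NXdl)
    {
        const int KPack = 8;          // 16 bytes per lane for F16
        const int NLane = NXdl;
        const int KLane = 64 / NLane; // 64 lanes per wavefront
        long out = 0;
        for(int n0 = 0; n0 < N / NLane; ++n0)
            for(int k0 = 0; k0 < K / (KLane * KPack); ++k0)
                for(int kl = 0; kl < KLane; ++kl)
                    for(int nl = 0; nl < NLane; ++nl)
                        for(int kp = 0; kp < KPack; ++kp)
                        {
                            const int n = n0 * NLane + nl;
                            const int k = (k0 * KLane + kl) * KPack + kp;
                            dst[out++] = src[(long)n * K + k];
                        }
    }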
@@ -145,20 +145,23 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
 ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
 ///###### RCR
 // kernel 1: 256->32x128x128
-// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
-// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
+// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
+// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F16>;
 < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
 AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
-32, 128, 256,
-16, 16,
+32, 128, 128,
+8, 8,
 32, 32,
 1, 1,
-S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
-ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
+S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
+ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F16>;
 // kernel 2: 128->32x128x128
-// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
+// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
 // clang-format on
@@ -230,9 +233,19 @@ int main(int argc, char* argv[])
 return HostTensorDescriptor({row, col}, {1_uz, stride});
 }
 };
+const ck::index_t experts = 8;
+Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
+Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({M}, {1}));
+for (int i = 0; i < experts; i++) {
+expert_ids.mData[i] = i;
+}
+for (int i = 0; i < M; i++) {
+sorted_token_ids.mData[i] = i % (M / 2);
+}
 Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
-Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{}));
 Tensor<B0DataType> b0_preshuffled(
 f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
 Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
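Note: these two tables are the MoE routing inputs. expert_ids names the expert that owns each MPerBlock-row tile of the sorted problem, and sorted_token_ids maps every row slot back to a source row of A; with i % (M / 2) each token is simply served twice. A hedged host-side reference of the gather the kernel performs per tile (illustrative names, not part of the commit):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Assumed semantics: materialize the A rows that tile `tile` reads;
    // the device kernel does the same gather via its A blockwise copy.
    template <typename T>
    std::vector<T> gatherTileRows(const std::vector<T>& a, // M x K, row-major
                                  const std::vector<int32_t>& sorted_token_ids,
                                  int tile, int MPerBlock, int K)
    {
        std::vector<T> rows(static_cast<std::size_t>(MPerBlock) * K);
        for(int i = 0; i < MPerBlock; ++i)
        {
            const int32_t token = sorted_token_ids[static_cast<std::size_t>(tile) * MPerBlock + i];
            for(int k = 0; k < K; ++k)
                rows[static_cast<std::size_t>(i) * K + k] = a[static_cast<std::size_t>(token) * K + k];
        }
        return rows;
    }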
@@ -267,12 +280,16 @@ int main(int argc, char* argv[])
 d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
 d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
 }
+DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
+DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
 DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
 DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
 DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
 DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
 DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
+sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
+expert_ids_dev.ToDevice(expert_ids.mData.data());
 a0_device_buf.ToDevice(a0_m_k.mData.data());
 d0_device_buf.ToDevice(d0_m_n.mData.data());
 d1_device_buf.ToDevice(d1_m_n.mData.data());
@@ -297,7 +314,9 @@ int main(int argc, char* argv[])
 auto invoker = device_op.MakeInvoker();
 auto argument =
-device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
+device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
+expert_ids_dev.GetDeviceBuffer(),
+a0_device_buf.GetDeviceBuffer(),
 b0_device_buf.GetDeviceBuffer(),
 std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
 d1_device_buf.GetDeviceBuffer()},
@@ -321,7 +340,7 @@ int main(int argc, char* argv[])
 "not support this GEMM problem");
 }
-float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, true, 50});
+float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
 std::size_t flop = std::size_t(2) * M * N * K;
 std::size_t num_btype =
...
@@ -52,7 +52,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
 __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
 const SrcDesc& src_desc,
 const Index& src_block_slice_origin,
 const SrcElementwiseOperation& src_element_op,
 const DstDesc& dst_desc,
 const Index& dst_block_slice_origin,
@@ -83,7 +83,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
 {
 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-make_multi_index(ThreadGroup::GetThreadId()));
+make_multi_index(ThreadGroup::GetThreadId() % 8));
 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -100,7 +100,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
 {
 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-make_multi_index(ThreadGroup::GetThreadId()));
+make_multi_index(ThreadGroup::GetThreadId() % 8));
 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
...
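Note on the % 8 above: a hedged reading is that, with the A-copy cluster S<8, 32, 1> this commit uses, eight threads cooperate along K per row, so the cluster index is reduced to a thread's position inside its K group and the row is chosen instead by the per-thread token_offset origin passed in by the gridwise kernel. The bare literal would be wrong for other cluster shapes; a sketch of the presumed intent (KLoadThreads as defined in the gridwise code, not available in this file):

    // Assumed generalization, not the commit's code:
    const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
        make_multi_index(ThreadGroup::GetThreadId() % KLoadThreads));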
@@ -157,7 +157,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
 }
 index_t gdx, gdy, gdz;
-std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
+std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
 float ave_time = 0;
@@ -249,30 +249,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
 // Tail number always full
 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
 {
-if(arg.KBatch > 1)
-{
-if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
-{
-const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
-GridwiseGemm,
-true,
-InMemoryDataOperationEnum::AtomicAdd,
-minimum_occupancy,
-TailNumber::Odd>;
-Run(kernel);
-}
-else
-{
-const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
-GridwiseGemm,
-true,
-InMemoryDataOperationEnum::AtomicAdd,
-minimum_occupancy,
-TailNumber::Even>;
-Run(kernel);
-}
-}
-else
+// if(arg.KBatch > 1)
+// {
+// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+// {
+// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::AtomicAdd,
+// minimum_occupancy,
+// TailNumber::Odd>;
+// Run(kernel);
+// }
+// else
+// {
+// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::AtomicAdd,
+// minimum_occupancy,
+// TailNumber::Even>;
+// Run(kernel);
+// }
+// }
+// else
 {
 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
 {
@@ -296,175 +296,64 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
 }
 }
 }
-else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
-{
-if(arg.KBatch > 1)
-{
-if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
-{
-const auto kernel =
-kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
-GridwiseGemm,
-true,
-InMemoryDataOperationEnum::AtomicAdd,
-minimum_occupancy,
-TailNumber::Odd>;
-Run(kernel);
-}
-else
-{
-const auto kernel =
-kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
-GridwiseGemm,
-true,
-InMemoryDataOperationEnum::AtomicAdd,
-minimum_occupancy,
-TailNumber::Even>;
-Run(kernel);
-}
-}
-else
-{
-if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
-{
-const auto kernel =
-kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
-GridwiseGemm,
-true,
-InMemoryDataOperationEnum::Set,
-minimum_occupancy,
-TailNumber::Odd>;
-Run(kernel);
-}
-else
-{
-const auto kernel =
-kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
-GridwiseGemm,
-true,
-InMemoryDataOperationEnum::Set,
-minimum_occupancy,
-TailNumber::Even>;
-Run(kernel);
-}
-}
-}
+// else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+// {
+// if(arg.KBatch > 1)
+// {
+// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+// {
+// const auto kernel =
+// kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::AtomicAdd,
+// minimum_occupancy,
+// TailNumber::Odd>;
+// Run(kernel);
+// }
+// else
+// {
+// const auto kernel =
+// kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::AtomicAdd,
+// minimum_occupancy,
+// TailNumber::Even>;
+// Run(kernel);
+// }
+// }
+// else
+// {
+// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+// {
+// const auto kernel =
+// kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::Set,
+// minimum_occupancy,
+// TailNumber::Odd>;
+// Run(kernel);
+// }
+// else
+// {
+// const auto kernel =
+// kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
+// GridwiseGemm,
+// true,
+// InMemoryDataOperationEnum::Set,
+// minimum_occupancy,
+// TailNumber::Even>;
+// Run(kernel);
+// }
+// }
+// }
 else
 {
 throw std::runtime_error("todo: only v1 & v2 support now");
 }
 }
#if 0
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
throw std::runtime_error("todo: only v3 support now");
}
}
#endif
 return ave_time;
 }
@@ -517,7 +406,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
 }
-static auto MakeArgument(const void* p_a,
+static auto MakeArgument(const void* p_sorted_token_ids,
+const void* p_sorted_expert_ids,
+const void* p_a,
 const void* p_b,
 std::array<const void*, NumDTensor> p_ds,
 void* p_c,
@@ -533,7 +424,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
 BElementwiseOperation b_element_op,
 CElementwiseOperation c_element_op)
 {
-return Argument{static_cast<const ADataType*>(p_a),
+return Argument{static_cast<const index_t*>(p_sorted_token_ids),
+static_cast<const index_t*>(p_sorted_expert_ids),
+static_cast<const ADataType*>(p_a),
 static_cast<const BDataType*>(p_b),
 p_ds,
 static_cast<CDataType*>(p_c),
@@ -553,7 +446,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
 static auto MakeInvoker() { return Invoker{}; }
 // polymorphic
-std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+std::unique_ptr<BaseArgument> MakeArgumentPointer(
+const void* p_a,
 const void* p_b,
 std::array<const void*, NumDTensor> p_ds,
 void* p_c,
@@ -569,7 +463,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
 BElementwiseOperation b_element_op,
 CElementwiseOperation c_element_op) override
 {
-return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
+return std::make_unique<Argument>(nullptr, nullptr,
+static_cast<const ADataType*>(p_a),
 static_cast<const BDataType*>(p_b),
 p_ds,
 static_cast<CDataType*>(p_c),
...
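Note the asymmetry these hunks introduce: the typed MakeArgument now requires the two routing tables, while the polymorphic MakeArgumentPointer keeps its old signature and forwards null pointers. A hedged summary of the two construction paths:

    // Typed path (MoE-aware), as used by the updated example:
    //   MakeArgument(p_sorted_token_ids, p_sorted_expert_ids, p_a, p_b, ...)
    // Polymorphic path, signature unchanged:
    //   MakeArgumentPointer(p_a, p_b, ...) -> Argument{nullptr, nullptr, p_a, ...}
    // Caution (assumption): a kernel reached through the polymorphic path would
    // dereference null token/expert tables, so that path presumably opts out of
    // the MoE rework for now.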
@@ -44,6 +44,8 @@ __global__ void
 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+karg.p_sorted_token_ids,
+karg.p_sorted_expert_ids,
 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 karg.p_ds_grid,
@@ -173,6 +175,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 static constexpr index_t NLane = NPerXdl;
 static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
 static_assert(NWave * warpSize == BlockSize);
+static constexpr index_t Experts = 8;
+static constexpr index_t SortedTileSize = 32;
 static constexpr auto MakeDsGridPointer()
 {
@@ -189,9 +194,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 using ThisThreadBlock = ThisThreadBlock<BlockSize>;
-__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
+__host__ static auto CalculateGridSize(index_t M, index_t N)
 {
-return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch);
+return std::make_tuple(math::integer_divide_ceil(N, NPerBlock),
+math::integer_divide_ceil(M, MPerBlock),
+1);
 }
 __host__ __device__ static auto CalculateMPadded(index_t M)
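With split-K dropped, the launch grid becomes a plain 2-D tile grid: blockIdx.x walks N tiles and blockIdx.y walks M tiles, which is what lets the kernel index p_sorted_expert_ids directly by blockIdx.y. A worked example (tile sizes from the example's kernel 1; M and N chosen for illustration):

    // MPerBlock = 32, NPerBlock = 128, M = 256, N = 512:
    //   gdx = integer_divide_ceil(512, 128) = 4  // N tiles -> blockIdx.x
    //   gdy = integer_divide_ceil(256, 32)  = 8  // M tiles -> blockIdx.y, one expert id each
    //   gdz = 1                                  // KBatch no longer multiplies the grid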
@@ -477,45 +484,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 make_right_pad_transform(N, NPad - N)),
 make_tuple(Sequence<0>{}, Sequence<1>{}),
 make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
 }
 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
@@ -617,7 +585,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 // Argument
 struct Argument : public tensor_operation::device::BaseArgument, public Problem
 {
-__host__ Argument(const ADataType* p_a_grid_,
+__host__ Argument(
+const index_t* p_sorted_token_ids_,
+const index_t* p_sorted_expert_ids_,
+const ADataType* p_a_grid_,
 const BDataType* p_b_grid_,
 std::array<const void*, NumDTensor> p_ds_grid_,
 CDataType* p_c_grid_,
@@ -633,6 +604,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 BElementwiseOperation b_element_op_,
 CElementwiseOperation c_element_op_)
 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
+p_sorted_token_ids{p_sorted_token_ids_},
+p_sorted_expert_ids{p_sorted_expert_ids_},
 p_a_grid{p_a_grid_},
 p_b_grid{p_b_grid_},
 p_ds_grid{},
@@ -651,6 +625,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 });
 }
+const index_t* p_sorted_token_ids;
+const index_t* p_sorted_expert_ids;
 const ADataType* p_a_grid;
 const BDataType* p_b_grid;
 DsGridPointer p_ds_grid;
@@ -1107,12 +1083,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 // return block_id to C matrix tile idx (m0, n0) mapping
 // if arch = gfx942
-using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
+// using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
 template <bool HasMainKBlockLoop,
 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
 TailNumber TailNum = TailNumber::Odd>
-__device__ static void Run(const ADataType* p_a_grid,
+__device__ static void Run(
+const index_t* p_sorted_token_ids,
+const index_t* p_sorted_expert_ids,
+const ADataType* p_a_grid,
 const BDataType* p_b_grid,
 DsGridPointer& p_ds_grid,
 CDataType* p_c_grid,
@@ -1121,35 +1100,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 AElementwiseOperation a_element_op,
 BElementwiseOperation b_element_op,
 CElementwiseOperation c_element_op)
{
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_c_grid,
p_shared,
problem,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
template <typename Block2CTileMap,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
 {
 ignore = b_element_op;
 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
@@ -1164,34 +1114,34 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
-const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
-const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize());
-auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-const auto block_work_idx =
-block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
-if(!block_2_ctile_map.ValidCTileIndex(
-block_work_idx,
-make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
-c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
-{
-return;
-}
-const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
-const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
-// HACK: this force m/n_block_data_idx_on_grid into SGPR
+const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
+const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
+const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
+// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
+constexpr auto MLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
+constexpr auto KLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0) * ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
+constexpr auto MLoadRepeats = MPerBlock / MLoadThreads;
+static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
+const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads;
+index_t token_offset = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[token_pos]);
 const index_t m_block_data_idx_on_grid =
 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
+const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
 // N0, K0, Blocksize*KPack
 const index_t n_block_data_idx_on_grid =
-__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
+__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
+const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
+const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
+auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 // A matrix in LDS memory, dst of blockwise copy
 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
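The B addressing change is confined to the base pointer: the experts' preshuffled weight matrices sit back to back, so expert e begins e * N * K elements into p_b_grid, matching the host-side K x (N * experts) allocation in the example; readfirstlane keeps the offset scalar. A quick arithmetic check (sizes assumed for illustration):

    // expert_stride = problem.N * problem.K  (one weight matrix per expert)
    // e.g. N = 512, K = 256, expert_id = p_sorted_expert_ids[blockIdx.y] = 3:
    //   base = p_b_grid + 3 * 512 * 256 = p_b_grid + 393216 elements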
@@ -1199,7 +1149,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 // B matrix in LDS memory, dst of blockwise copy
 // dummy
 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-
 // A matrix blockwise copy
 auto a_blockwise_copy =
 ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
@@ -1225,7 +1174,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 true,
 BlockwiseGemmPipe::GlobalBufferNum>(
 a_grid_desc_ak0_m_ak1,
-make_multi_index(0, m_block_data_idx_on_grid, 0),
+make_multi_index(0, token_offset, 0),
 a_element_op,
 a_block_desc_ak0_m_ak1,
 make_multi_index(0, 0, 0),
@@ -1432,7 +1381,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 make_tuple(make_multi_index(0, 0, 0, 0)),
 generate_tuple(
 [&](auto) {
-return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
+return make_multi_index(block_m_id, 0, block_n_id, 0);
+// return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
 },
 Number<NumDTensor>{}));
@@ -1557,19 +1507,19 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 BElementwiseOperation b_element_op,
 CElementwiseOperation c_element_op)
 {
-const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
-Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-p_a_grid,
-p_b_grid,
-p_ds_grid,
-p_c_grid,
-p_shared,
-p_shared1,
-problem,
-a_element_op,
-b_element_op,
-c_element_op,
-block_2_ctile_map);
+// const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
+// Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+// p_a_grid,
+// p_b_grid,
+// p_ds_grid,
+// p_c_grid,
+// p_shared,
+// p_shared1,
+// problem,
+// a_element_op,
+// b_element_op,
+// c_element_op,
+// block_2_ctile_map);
 }
 template <typename Block2CTileMap,
@@ -1588,402 +1538,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
 CElementwiseOperation c_element_op,
 const Block2CTileMap& block_2_ctile_map)
 {
ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<0, 1, 2, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id(),
0,
KPack * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
using EDataType = CDataType;
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
c_grid_desc_mblock_mperblock_nblock_nperblock;
using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
 }
 };
...