Commit d47461d7 authored by aska-0096
Browse files

Add compute-friendly pipeline for bpreshuffle case; remove enable-post-misched=0 flag.

parent cee23c47
...@@ -227,13 +227,6 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000) ...@@ -227,13 +227,6 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000)
add_compile_options("SHELL: -mllvm --lsr-drop-solution=1") add_compile_options("SHELL: -mllvm --lsr-drop-solution=1")
endif() endif()
endif() endif()
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
if(HAS_ENABLE_POST_MISCHED)
message("Adding the enable-post-misched=0 compiler flag")
add_compile_options("SHELL: -mllvm -enable-post-misched=0")
endif()
endif()
set(check-coerce) set(check-coerce)
check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132) if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
......
...@@ -2,6 +2,6 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult ...@@ -2,6 +2,6 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker) # target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker) target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
...@@ -149,14 +149,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -149,14 +149,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
32, 512, 128, 256, 256, 128,
16, 16, 16, 16,
32, 32, 32, 32,
1, 4, 4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v2, FP8>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
namespace ck { namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer, template <BlockGemmPipelineVersion BlkGemmPipelineVer,
...@@ -76,6 +77,30 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector() ...@@ -76,6 +77,30 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
NRepeat, NRepeat,
KPack>{}; KPack>{};
} }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else else
{ {
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
......
...@@ -249,6 +249,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -249,6 +249,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
// Tail number always full // Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
#if 0
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...@@ -295,9 +296,19 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -295,9 +296,19 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
Run(kernel); Run(kernel);
} }
} }
#endif
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
} }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
#if 0
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...@@ -348,6 +359,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -348,6 +359,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
Run(kernel); Run(kernel);
} }
} }
#endif
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
} }
else else
{ {
...@@ -359,6 +378,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -359,6 +378,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{ {
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
#if 0
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...@@ -405,8 +425,29 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -405,8 +425,29 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
Run(kernel); Run(kernel);
} }
} }
#endif
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
} }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
......
...@@ -172,7 +172,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -172,7 +172,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr index_t KRepeat = KPerBlock / KLane / KPack; static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl; static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NWave * warpSize == BlockSize);
static constexpr auto MakeDsGridPointer() static constexpr auto MakeDsGridPointer()
{ {
...@@ -1202,7 +1201,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1202,7 +1201,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true, true,
2>( BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op, a_element_op,
...@@ -1221,13 +1220,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1221,13 +1220,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
decltype(b_grid_desc_bpreshuffled), decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>, Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<0, 1, 2, 3>, Sequence<1, 2, 0, 3>,
3, 3,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled, true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid, make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id(), get_warp_local_1d_id() % NWave,
0, 0,
KPack * (get_thread_local_1d_id() % warpSize))); KPack * (get_thread_local_1d_id() % warpSize)));
...@@ -1661,13 +1660,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1661,13 +1660,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
decltype(b_grid_desc_bpreshuffled), decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>, Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<0, 1, 2, 3>, Sequence<1, 2, 0, 3>,
3, 3,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled, true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid, make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id(), get_warp_local_1d_id() % NWave,
0, 0,
KPack * (get_thread_local_1d_id() % warpSize))); KPack * (get_thread_local_1d_id() % warpSize)));
......
...@@ -17,7 +17,7 @@ fi ...@@ -17,7 +17,7 @@ fi
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \ -D GPU_TARGETS=$GPU_TARGETS \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment