Commit dc1c2bf8 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into letaoqin/update_layernorm

parents 5cfd751b a285d6f9
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...@@ -77,20 +77,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -77,20 +77,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kN0,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0WarpTile>, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kN0,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0WarpTile>, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::PDataType, BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::OaccDataType, typename Problem::OaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kN1, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kN1,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1WarpTile>, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
auto warp_gemm = [&]() { auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> && if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
......
...@@ -46,7 +46,7 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< ...@@ -46,7 +46,7 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 64, 16, 8, 32, 32, 3, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 64, 16, 8, 32, 32, 3, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
// We prefer following instance, however, existing compiler bug cause it failed to generate sanity code. // We prefer following instance, however, existing compiler bug cause it failed to generate sanity code.
// DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
......
...@@ -18,4 +18,9 @@ if(result EQUAL 0) ...@@ -18,4 +18,9 @@ if(result EQUAL 0)
target_link_libraries(test_bf8 PRIVATE utility) target_link_libraries(test_bf8 PRIVATE utility)
endif() endif()
add_gtest_executable(test_custom_type test_custom_type.cpp)
if(result EQUAL 0)
target_link_libraries(test_custom_type PRIVATE utility)
endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp) add_gtest_executable(test_type_convert_const type_convert_const.cpp)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment