Commit bf210540 authored by Aleksander Dudek
Browse files

Merge branch 'develop' into reference_gemm_alloc

parents 36d1b311 0e54d7ae
...@@ -810,21 +810,46 @@ struct FusedMoeGemmPipelineFlatmmPolicy ...@@ -810,21 +810,46 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1() CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{ {
using S_ = typename Problem::BlockShape; using S_ = typename Problem::BlockShape;
using T_ = typename Problem::Traits;
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> && if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> && std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> && std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 && S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == false)
{ {
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{}; return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
} }
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> && else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> && std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> && std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 && S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == false)
{ {
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{}; return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
T_::PipeInterleave == true)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
} }
} }
}; };
......
...@@ -22,7 +22,8 @@ template <bool IsGateOnly_, ...@@ -22,7 +22,8 @@ template <bool IsGateOnly_,
FusedMoeGemmWeightPermuteEnum PermuteEnum_ = FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten, FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
bool PadHiddenSize_ = false, bool PadHiddenSize_ = false,
bool PadIntermediateSize_ = false> bool PadIntermediateSize_ = false,
bool PipeInterleave_ = true>
struct FusedMoeGemmTraits struct FusedMoeGemmTraits
{ {
// Gate+Up or Gate only // Gate+Up or Gate only
...@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits ...@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_; static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
static constexpr bool PadHiddenSize = PadHiddenSize_; static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_; static constexpr bool PadIntermediateSize = PadIntermediateSize_;
static constexpr bool PipeInterleave = PipeInterleave_;
}; };
// Note: this needs to be a bit mask // Note: this needs to be a bit mask
......
...@@ -19,7 +19,8 @@ struct SmoothquantHostArgs ...@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // input row_stride
index_t y_stride; // output row_stride
}; };
// TODO: Extract some type to wrapper class // TODO: Extract some type to wrapper class
...@@ -58,14 +59,21 @@ struct Smoothquant ...@@ -58,14 +59,21 @@ struct Smoothquant
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t x_stride; // input row_stride
index_t y_stride; // out row_stride
}; };
using Hargs = SmoothquantHostArgs; using Hargs = SmoothquantHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{ {
return Kargs{ return Kargs{hargs.p_x,
hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride}; hargs.p_xscale,
hargs.p_yscale,
hargs.p_qy,
hargs.m,
hargs.n,
hargs.x_stride,
hargs.y_stride};
} }
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
...@@ -116,7 +124,7 @@ struct Smoothquant ...@@ -116,7 +124,7 @@ struct Smoothquant
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.x_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
...@@ -157,7 +165,7 @@ struct Smoothquant ...@@ -157,7 +165,7 @@ struct Smoothquant
auto tmp_ = make_naive_tensor_view<address_space_enum::global>( auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<QYDataType*>(kargs.p_qy), static_cast<QYDataType*>(kargs.p_qy),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1), make_tuple(kargs.y_stride, 1),
number<Vector_N>{}, number<Vector_N>{},
number<1>{}); number<1>{});
......
...@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = ...@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
......
...@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std ...@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef __gfx94__ #ifdef __gfx94__
// Compute friendly // Compute friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
...@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std: ...@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
......
...@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
int StrideB, int StrideB,
int StrideC, int StrideC,
int BatchCount, int BatchCount,
int KBatch,
int n_warmup, int n_warmup,
int n_iter, int n_iter,
uint64_t rotating = 0) uint64_t rotating = 0)
...@@ -147,14 +148,23 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -147,14 +148,23 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device op instances // profile device op instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr; std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
// false branch for multi d dl kernel
argument_ptr = if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{}, {},
...@@ -173,15 +183,13 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -173,15 +183,13 @@ bool profile_gemm_universal_batched_impl(int do_verification,
BatchStrideC, BatchStrideC,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{},
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run( float ave_time = invoker_ptr->Run(
...@@ -199,7 +207,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -199,7 +207,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl; << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
...@@ -207,6 +215,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -207,6 +215,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
} }
if(do_verification) if(do_verification)
...@@ -219,7 +228,8 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -219,7 +228,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
{ {
LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") LogRangeAsType<float>(
std::cout << "c_host: ", c_g_m_n_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") std::cout << "c_device: ", c_g_m_n_device_result.mData, ",")
...@@ -229,7 +239,9 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -229,7 +239,9 @@ bool profile_gemm_universal_batched_impl(int do_verification,
} }
else else
{ {
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
} }
} }
...@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification, ...@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K
<< " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
<< ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops
<< " GB/s, " << best_op_name << std::endl; << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return pass;
} }
......
...@@ -144,6 +144,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -144,6 +144,7 @@ bool profile_gemm_universal_impl(int do_verification,
} }
std::string best_op_name; std::string best_op_name;
std::optional<std::string> best_op_object_name;
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
...@@ -226,6 +227,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -226,6 +227,7 @@ bool profile_gemm_universal_impl(int do_verification,
} }
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
std::optional<std::string> op_obj_name = op_ptr->GetObjectName();
float ave_time = invoker_ptr->Run(argument_ptr.get(), float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, StreamConfig{nullptr,
...@@ -252,6 +254,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -252,6 +254,7 @@ bool profile_gemm_universal_impl(int do_verification,
if(tflops > best_tflops && ave_time > 1e-10) if(tflops > best_tflops && ave_time > 1e-10)
{ {
best_op_name = op_name; best_op_name = op_name;
best_op_object_name = op_obj_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
...@@ -306,6 +309,9 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -306,6 +309,9 @@ bool profile_gemm_universal_impl(int do_verification,
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl; << " GB/s, " << best_op_name << std::endl;
if(best_op_object_name)
std::cout << best_op_object_name.value() << std::endl;
return pass; return pass;
} }
......
...@@ -31,7 +31,7 @@ enum struct GemmDataType ...@@ -31,7 +31,7 @@ enum struct GemmDataType
int profile_batched_gemm_universal(int argc, char* argv[]) int profile_batched_gemm_universal(int argc, char* argv[])
{ {
if(argc != 18 && argc != 21) if(argc != 19 && argc != 22)
{ {
// clang-format off // clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
...@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg7: time kernel (0=n0, 1=yes)\n");
printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n");
printf("optional:\n"); printf("optional:\n");
printf("arg18: number of warm-up cycles (default 1)\n"); printf("arg19: number of warm-up cycles (default 1)\n");
printf("arg19: number of iterations (default 10)\n"); printf("arg20: number of iterations (default 10)\n");
printf("arg20: memory for rotating buffer (default 0, size in MB)\n"); printf("arg21: memory for rotating buffer (default 0, size in MB)\n");
// clang-format on // clang-format on
exit(1); exit(1);
} }
...@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
uint64_t rotating = 0; uint64_t rotating = 0;
if(argc == 21) if(argc == 22)
{ {
n_warmup = std::stoi(argv[18]); n_warmup = std::stoi(argv[19]);
n_iter = std::stoi(argv[19]); n_iter = std::stoi(argv[20]);
rotating = std::stoull(argv[20]) * 1024 * 1024; rotating = std::stoull(argv[21]) * 1024 * 1024;
} }
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
...@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
const int BatchStrideC = std::stoi(argv[16]); const int BatchStrideC = std::stoi(argv[16]);
const int BatchCount = std::stoi(argv[17]); const int BatchCount = std::stoi(argv[17]);
const int KBatch = std::stoi(argv[18]);
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t; using F8 = ck::f8_t;
...@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) ...@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
StrideB_, StrideB_,
StrideC_, StrideC_,
BatchCount, BatchCount,
KBatch,
n_warmup, n_warmup,
n_iter, n_iter,
rotating); rotating);
......
...@@ -332,7 +332,7 @@ def main(): ...@@ -332,7 +332,7 @@ def main():
table_name="ck_fmha_bwd_tflops" table_name="ck_fmha_bwd_tflops"
tflops_base = get_baseline(table_name,conn) tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn) store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
conn.close() conn.close()
#compare the results to the baseline if baseline exists #compare the results to the baseline if baseline exists
......
# Currently ck_tile is only built on gfx9 # Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9") if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_gemm_mem_pipeline test_gemm_mem_pipeline.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp)
endif() endif()
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "test_gemm_mem_pipeline_util.hpp" #include "test_gemm_pipeline_util.hpp"
using F16 = ck_tile::half_t; using F16 = ck_tile::half_t;
using F32 = float; using F32 = float;
...@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, ...@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>; ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>; ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave> std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes); TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes);
#include "test_gemm_mem_pipeline_ut_cases.inc" #include "test_gemm_pipeline_ut_cases.inc"
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#pragma once #pragma once
TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) TYPED_TEST(TestCkTileGemmPipeline, SmallM)
{ {
std::vector<int> Ms{1, 2, 3, 4, 5, 6}; std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) ...@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) ...@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) TYPED_TEST(TestCkTileGemmPipeline, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{127};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) ...@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, Regular) TYPED_TEST(TestCkTileGemmPipeline, Regular)
{ {
std::vector<int> Ms{512}; std::vector<int> Ms{512};
constexpr int N = 1024; constexpr int N = 1024;
...@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular) ...@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
this->Run(M, N, K); this->Run(M, N, K);
} }
TYPED_TEST(TestCkTileGemmMemPipeline, NotSupportedArgument) TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
{ {
constexpr int M = 512; constexpr int M = 512;
constexpr int N = 1025; constexpr int N = 1025;
......
...@@ -11,8 +11,13 @@ ...@@ -11,8 +11,13 @@
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm.hpp"
/// Selects which GEMM pipeline implementation the test fixture instantiates.
enum class GemmPipelineType
{
    /// Fixture uses the ck_tile `...AgBgCrMem` pipeline classes.
    Mem,
    /// Fixture uses the ck_tile `...AgBgCrCompV3` pipeline classes.
    Comp
};
template <typename Tuple> template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test class TestCkTileGemmPipeline : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
...@@ -23,6 +28,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -23,6 +28,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using AccDataType = std::tuple_element_t<5, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
// TODO: expose tile size through test t-param ? // TODO: expose tile size through test t-param ?
struct gemm_args struct gemm_args
...@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -74,8 +80,13 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = std::conditional_t<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
...@@ -85,7 +96,18 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -85,7 +96,18 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using GemmPipeline =
std::conditional_t<PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
Scheduler,
has_hot_loop_v,
tail_number_v>>,
ck_tile::GemmPipelineAgBgCrCompV3<
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
...@@ -93,7 +115,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -93,7 +115,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
Traits, Traits,
Scheduler, Scheduler,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b, args.p_b,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment