Commit ba1756e3 authored by Po-Yen, Chen

Merge branch 'develop' into feature/add-permute-device-op

parents ea23062c 7c788e10
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
+add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_custom_target(example_batched_gemm_scale_softmax_gemm)
+add_custom_target(example_gemm_scale_softmax_gemm)
-add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
+add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
-add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
+add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
-add_dependencies(example_batched_gemm_scale_softmax_gemm example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16)
+add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
-static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
@@ -149,8 +149,8 @@ int main(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-ck::index_t M = 128;
+ck::index_t M = 120;
-ck::index_t N = 1024;
+ck::index_t N = 1000;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t StrideA = -1;
...
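For orientation, the fused operation these examples exercise is C = softmax(scale * A*B0) * B1, applied row-wise over the N dimension, matching the `C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o` comment above (Acc0ElementOp is the Scale). A small, unoptimized host sketch for a single batch; the sizes, the 1/sqrt(K) scale, and the row-major layouts are illustrative assumptions, not the example's actual configuration:

```cpp
// Naive single-batch host reference for C = softmax(scale * A*B0) * B1.
// Sizes, scale and row-major layouts here are assumptions for illustration only.
#include <cmath>
#include <cstdio>
#include <vector>

void gemm_softmax_gemm_host(const std::vector<float>& A,  // M x K
                            const std::vector<float>& B0, // K x N
                            const std::vector<float>& B1, // N x O
                            std::vector<float>& C,        // M x O
                            int M, int N, int K, int O, float scale)
{
    std::vector<float> S(N); // one row of scale * A*B0, then its softmax weights
    for(int m = 0; m < M; ++m)
    {
        float row_max = -INFINITY;
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B0[k * N + n];
            S[n]    = scale * acc;
            row_max = std::fmax(row_max, S[n]);
        }
        float row_sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            S[n] = std::exp(S[n] - row_max);
            row_sum += S[n];
        }
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += (S[n] / row_sum) * B1[n * O + o];
            C[m * O + o] = acc;
        }
    }
}

int main()
{
    const int M = 2, N = 3, K = 2, O = 2;
    const std::vector<float> A  = {1, 0, 0, 1};       // 2 x 2
    const std::vector<float> B0 = {1, 2, 3, 4, 5, 6}; // 2 x 3
    const std::vector<float> B1 = {1, 0, 0, 1, 1, 1}; // 3 x 2
    std::vector<float> C(M * O);
    gemm_softmax_gemm_host(A, B0, B1, C, M, N, K, O, 1.f / std::sqrt(float(K)));
    for(float c : C) std::printf("%f\n", c);
}
```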
@@ -55,7 +55,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ALayout,
@@ -73,7 +73,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma
Acc0ElementOp,
B1ElementOp,
CElementOp,
-GemmDefault,
+GemmSpec,
1,
256,
128, // MPerBlock
@@ -144,8 +144,8 @@ int main(int argc, char* argv[])
bool time_kernel = false;
// GEMM shape
-ck::index_t M = 1024;
+ck::index_t M = 1020;
-ck::index_t N = 1024;
+ck::index_t N = 1020;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
...
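Switching this example from GemmSpecialization::Default to MNPadding is what allows the new non-tile-multiple sizes (M = N = 1020) above: with padding enabled, M and N are rounded up to whole block tiles. A minimal sketch of that round-up, using the MPerBlock of 128 shown in the hunk and assuming an NPerBlock of 128 purely for illustration:

```cpp
// Minimal sketch of MNPadding-style round-up; NPerBlock = 128 is an assumption for illustration.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128; // as in the instance above
    constexpr int NPerBlock = 128; // assumed
    const int M = 1020, N = 1020;  // the new problem sizes

    const int M_padded = (M + MPerBlock - 1) / MPerBlock * MPerBlock; // 1024
    const int N_padded = (N + NPerBlock - 1) / NPerBlock * NPerBlock; // 1024

    std::printf("M %d -> %d, N %d -> %d\n", M, N, M_padded, N_padded);
}
```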
@@ -649,6 +649,9 @@ struct BlockwiseGemmXdlops_v2
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
...
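The new static_assert simply encodes a divisibility requirement between the per-thread K extent and the K packing factor, so misconfigured tuning parameters fail at compile time with a readable message. A trivial standalone illustration, with made-up numbers:

```cpp
// Standalone illustration of the constraint added above; the numbers are made up.
template <int KPerThread, int KPack>
struct KPackCheck
{
    static_assert(KPerThread % KPack == 0,
                  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
};

int main()
{
    KPackCheck<16, 8> ok; // 16 % 8 == 0, compiles
    (void)ok;
    // KPackCheck<16, 12> bad; // 16 % 12 != 0, the assertion above would fire
}
```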
@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
-if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
-                                arg.b_grid_desc_bk0_n_bk1_,
-                                arg.b1_grid_desc_bk0_n_bk1_,
-                                arg.c_grid_desc_m_n_,
-                                arg.block_2_ctile_map_))
+if(!DeviceOp::IsSupportedArgument(arg))
{
-throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
+throw std::runtime_error("wrong! unsupported argument");
}
const index_t grid_size =
...
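The change above makes Run() reuse the device op's own IsSupportedArgument() check and throw the uniform "wrong! unsupported argument" error instead of a separate gridwise validity message. A simplified standalone analogue of that validate-then-launch convention (the toy types below are illustrative, not CK's actual DeviceOp/Argument classes):

```cpp
// Simplified stand-in for the validate-then-launch convention used above; the types and
// the supported-argument predicate are illustrative, not CK's real classes.
#include <cstdio>
#include <stdexcept>

struct Argument
{
    int M, N; // toy problem description
};

struct DeviceOp
{
    static bool IsSupportedArgument(const Argument& arg) { return arg.M > 0 && arg.N > 0; }

    static float Run(const Argument& arg)
    {
        if(!IsSupportedArgument(arg))
        {
            throw std::runtime_error("wrong! unsupported argument");
        }
        // ... kernel launch would go here ...
        return 0.f; // elapsed-time placeholder
    }
};

int main()
{
    try
    {
        DeviceOp::Run(Argument{0, 128}); // rejected before any launch
    }
    catch(const std::runtime_error& e)
    {
        std::printf("%s\n", e.what());
    }
}
```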
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{
struct ProblemDesc
{
// Overall problem shape
index_t M;
index_t N;
index_t K;
index_t O;
index_t Batch;
// Stride for A/B0/B1; layout determined by template args
index_t StrideA;
index_t StrideB0;
index_t StrideB1;
index_t BatchStrideA;
index_t BatchStrideB0;
index_t BatchStrideB1;
// Lengths and strides for output C
std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides;
};
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
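A sketch of how a caller might fill one ProblemDesc per group before handing the vectors to MakeArgumentPointer. To keep it self-contained, the struct below mirrors the fields declared above rather than using the nested type of a concrete device op; the shapes, strides and the G0/G1 batch split are assumptions for illustration:

```cpp
// Illustrative, standalone mirror of the ProblemDesc fields above, filled for one group of a
// grouped attention-like problem. A real caller would use the nested ProblemDesc of a concrete
// DeviceGroupedGemmSoftmaxGemmPermute instance instead.
#include <cstdint>
#include <vector>

using index_t = int32_t;

struct ProblemDesc
{
    index_t M, N, K, O, Batch;
    index_t StrideA, StrideB0, StrideB1;
    index_t BatchStrideA, BatchStrideB0, BatchStrideB1;
    std::vector<index_t> c_gs_ms_os_lengths;
    std::vector<index_t> c_gs_ms_os_strides;
};

int main()
{
    const index_t G0 = 2, G1 = 8; // e.g. batch x heads; Batch = G0 * G1 (assumed split)
    const index_t M = 256, N = 256, K = 64, O = 64;

    ProblemDesc desc{};
    desc.M     = M;
    desc.N     = N;
    desc.K     = K;
    desc.O     = O;
    desc.Batch = G0 * G1;

    // Row-major A[M, K] and B1[N, O], column-major B0[N, K]; strides are illustrative.
    desc.StrideA       = K;
    desc.StrideB0      = K;
    desc.StrideB1      = O;
    desc.BatchStrideA  = M * K;
    desc.BatchStrideB0 = N * K;
    desc.BatchStrideB1 = N * O;

    // Output viewed as C[G0, G1, M, O], with strides chosen so the physical layout is the
    // permuted [G0, M, G1, O] ordering (the G1 "head" dimension interleaved with O).
    desc.c_gs_ms_os_lengths = {G0, G1, M, O};
    desc.c_gs_ms_os_strides = {M * G1 * O, O, G1 * O, 1};

    std::vector<ProblemDesc> problem_desc_vec = {desc}; // one entry per group
    return 0;
}
```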
@@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
return is_valid;
}
// This wrapper class is for grouped gemm: it subtracts a per-gemm offset from blockIdx so that
// the workgroups assigned to a given gemm problem have their top index rebased to the range
// [0, grid_size_per_gemm)
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
} // namespace ck
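A standalone analogue of what the wrapper does: given a per-gemm tile map that expects block ids starting at 0, the offsetted version lets several gemms share one flat grid by subtracting each gemm's starting block id. The row-major tile map below is a toy stand-in, not one of CK's actual BlockToCTileMap classes:

```cpp
// Simplified stand-in showing how OffsettedBlockToCTileMap rebases block ids. The underlying
// map here is a toy row-major tile map, not CK's real types.
#include <cassert>
#include <utility>

struct RowMajorTileMap // maps a 0-based block id to an (m-tile, n-tile) index
{
    int n_tiles;
    std::pair<int, int> CalculateBottomIndex(int block_id) const
    {
        return {block_id / n_tiles, block_id % n_tiles};
    }
};

template <typename UnderlyingMap>
struct OffsettedMap
{
    UnderlyingMap map;
    int block_start; // first flat block id assigned to this gemm

    std::pair<int, int> CalculateBottomIndex(int flat_block_id) const
    {
        return map.CalculateBottomIndex(flat_block_id - block_start);
    }
};

int main()
{
    // Two gemms share one grid: gemm0 owns blocks [0, 6), gemm1 owns blocks [6, 12).
    OffsettedMap<RowMajorTileMap> gemm1{RowMajorTileMap{3}, /*block_start=*/6};

    // Flat block id 7 is the second block of gemm1, i.e. tile (0, 1).
    assert(gemm1.CalculateBottomIndex(7) == std::make_pair(0, 1));
    return 0;
}
```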
@@ -881,9 +881,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc c_new =
    (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
     math::exp(max[iM] - running_max_new[iM]) * acc1) /
-    running_sum_new[iM]; // O_new
+    running_sum_new[iM]; // Formula by Dao et al.,
+                         // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
-c_thread_buf(I) = c_new;
+c_thread_buf(I) = c_new; // O_new
});
});
...
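The quoted update is the FlashAttention-style online softmax rescale: the previously normalized partial output c is rescaled by the old running sum and the shift in running max, the new block's unnormalized contribution acc1 is added, and the result is renormalized by the new running sum. A standalone numerical check of exactly that formula (the scores, values and block split below are made up, and only serve to confirm the block-by-block result matches the full softmax-weighted sum):

```cpp
// Standalone numerical check of the running-rescale update quoted above.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<float> s = {0.3f, 2.1f, -1.4f, 0.7f, 1.9f, -0.2f}; // attention scores
    const std::vector<float> v = {1.0f, -2.0f, 0.5f, 3.0f, -1.0f, 2.5f}; // values

    // Reference: ordinary softmax(s) dotted with v over the whole row.
    float m_ref = -INFINITY, sum_ref = 0.f, out_ref = 0.f;
    for(float x : s) m_ref = std::fmax(m_ref, x);
    for(std::size_t i = 0; i < s.size(); ++i) sum_ref += std::exp(s[i] - m_ref);
    for(std::size_t i = 0; i < s.size(); ++i) out_ref += std::exp(s[i] - m_ref) / sum_ref * v[i];

    // Online version: two blocks of three elements, carrying running max/sum/output.
    float running_max = -INFINITY, running_sum = 0.f, c = 0.f; // c = normalized partial output
    for(std::size_t b = 0; b < s.size(); b += 3)
    {
        float max = -INFINITY, sum = 0.f, acc1 = 0.f; // per-block max, exp-sum, unnormalized P*V
        for(std::size_t i = b; i < b + 3; ++i) max = std::fmax(max, s[i]);
        for(std::size_t i = b; i < b + 3; ++i)
        {
            const float p = std::exp(s[i] - max);
            sum += p;
            acc1 += p * v[i];
        }
        const float running_max_new = std::fmax(running_max, max);
        const float running_sum_new = std::exp(running_max - running_max_new) * running_sum +
                                      std::exp(max - running_max_new) * sum;
        c = (running_sum * std::exp(running_max - running_max_new) * c +
             std::exp(max - running_max_new) * acc1) /
            running_sum_new; // same update as c_new above
        running_max = running_max_new;
        running_sum = running_sum_new;
    }

    std::printf("reference %f, online %f\n", out_ref, c);
    assert(std::fabs(out_ref - c) < 1e-5f);
}
```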
@@ -34,17 +34,15 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
y0 = vy0.template AsType<half2_t>()[I0];
y1 = vy1.template AsType<half2_t>()[I0];
#else
-asm volatile("\n \
-        v_pack_b32_f16 %0, %1, %2 \n \
-        "
-             : "=v"(y0)
-             : "v"(x0), "v"(x1));
-
-asm volatile("\n \
-        v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \
-        "
-             : "=v"(y1)
-             : "v"(x0), "v"(x1));
+constexpr int32_t m0 = 0x05040100;
+constexpr int32_t m1 = 0x07060302;
+
+// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+//                   -- -- -- --     -- -- -- --      -  -  -  -
+//       index        7  6  5  4      3  2  1  0     33 77 44 88
+// index is reversed because of little endianness (least significant bits first)
+y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
+y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
#endif
}
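A host-side sketch of what the two selectors do. Each selector byte picks one byte out of the 8-byte concatenation {src0, src1}, with values 0-3 addressing src1 and 4-7 addressing src0, which is how m0 gathers the low 16-bit halves and m1 the high halves to transpose a 2x2 tile of fp16 values. The perm_b32 function below is a plain C++ stand-in for __builtin_amdgcn_perm (which is only available in device code) and does not model selector values above 7, which have special meanings on hardware:

```cpp
// Host emulation of the byte-select used above, verified against the worked example in the comment.
#include <cassert>
#include <cstdint>
#include <cstdio>

static uint32_t perm_b32(uint32_t src0, uint32_t src1, uint32_t sel)
{
    const uint64_t combined = (uint64_t(src0) << 32) | src1; // bytes 0..3 = src1, bytes 4..7 = src0
    uint32_t result = 0;
    for(int i = 0; i < 4; ++i)
    {
        const uint32_t s    = (sel >> (8 * i)) & 0xffu; // selector for result byte i
        const uint32_t byte = uint32_t(combined >> (8 * (s & 7))) & 0xffu;
        result |= byte << (8 * i);
    }
    return result;
}

int main()
{
    // The example from the source comment: result 0x33774488.
    assert(perm_b32(0x11223344u, 0x55667788u, 0x05010400u) == 0x33774488u);

    // transpose_fp16_2x2: x0 = {a0, a1}, x1 = {b0, b1}, each packed as two 16-bit halves with
    // a0/b0 in the low bits. m0 gathers the low halves -> y0 = {a0, b0}; m1 gathers the high
    // halves -> y1 = {a1, b1}.
    const uint32_t x0 = 0xAAAA1111u; // a1 = 0xAAAA, a0 = 0x1111
    const uint32_t x1 = 0xBBBB2222u; // b1 = 0xBBBB, b0 = 0x2222
    const uint32_t y0 = perm_b32(x1, x0, 0x05040100u);
    const uint32_t y1 = perm_b32(x1, x0, 0x07060302u);
    assert(y0 == 0x22221111u); // {a0, b0}
    assert(y1 == 0xBBBBAAAAu); // {a1, b1}
    std::printf("y0 = 0x%08X, y1 = 0x%08X\n", y0, y1);
}
```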
@@ -106,16 +104,14 @@ __device__ void transpose_int8_4x4(const int8x4_t& x0,
//                   -- -- -- --     -- -- -- --      -  -  -  -
//       index        7  6  5  4      3  2  1  0     33 77 44 88
// index is reversed because of little endianness (least significant bits first)
-// clang-format off
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m0));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m0));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m3));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m3));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z3) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-// clang-format on
+t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
+t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
+z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
+t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
+t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
+z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
y0 = bit_cast<int8x4_t>(z0);
y1 = bit_cast<int8x4_t>(z1);
...
@@ -55,6 +55,22 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
// clang-format on
>;
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances =
std::tuple<
// clang-format off
//#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
Col,
@@ -73,6 +89,9 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances{});
}
} // namespace instance
...
@@ -105,6 +105,19 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
{
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 16},
{256, 64, 160, 64, 16},
{1024, 1024, 80, 80, 16},
{1024, 64, 80, 64, 16},
{4096, 4096, 40, 40, 16},
{4096, 64, 40, 64, 16}};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
using ck::tensor_operation::device::GemmSpecialization;
// TODO: enable KPadding tests when it is implemented
...
@@ -29,14 +29,19 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
using B1Layout = std::tuple_element_t<6, Tuple>;
using CLayout = std::tuple_element_t<7, Tuple>;
-std::vector<std::vector<int>> lengths_ = {
-    {256, 256, 64, 64, 4},
+std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
{512, 512, 64, 64, 2},
{512, 512, 128, 128, 2},
{1024, 1024, 64, 64, 1},
{1024, 1024, 128, 128, 1},
-};
+{256, 256, 160, 160, 4},
{256, 64, 160, 64, 4},
{1024, 1024, 80, 80, 2},
{1024, 64, 80, 64, 2},
{4096, 4096, 40, 40, 1},
{4096, 64, 40, 64, 1}};
bool bench_ = false;
bool verify_ = true;
...