Commit edc494df authored by Anthony Chang

clang-format

parent 00331ee4
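The hunks below are whitespace-only: clang-format re-wraps declarations and call expressions that exceed the configured column limit, breaking after the assignment or the opening parenthesis and placing one argument per line aligned under the open bracket. A minimal sketch of the effect, using hypothetical names that do not come from this repository:

#include <cstdio>

// Stand-in for the kind of long argument list this commit reflows; the names are invented.
static int make_argument(int p_a, int p_b, int p_c, int m, int n, int k)
{
    return p_a + p_b + p_c + m + n + k;
}

int main()
{
    // Once the call no longer fits on one line, clang-format breaks after the '='
    // and aligns each argument under the first one, as in the hunks below.
    const int argument =
        make_argument(1,
                      2,
                      3,
                      256,
                      128,
                      32);
    std::printf("argument = %d\n", argument);
    return 0;
}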
@@ -48,10 +48,10 @@ using B0Layout = Col;
using B1Layout = Row;
using CLayout  = Row;

using AElementOp  = PassThrough;
using B0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp  = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -113,14 +113,19 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
        8>; // CShuffleBlockTransferScalarPerVector_NPerBlock

using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B0DataType,
                                                                                ADataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B0ElementOp,
                                                                                CElementOp>;

using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B1DataType,
                                                                                CDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B1ElementOp,
                                                                                CElementOp>;

int main(int argc, char* argv[])
{
@@ -179,15 +184,15 @@ int main(int argc, char* argv[])
        BatchCount = std::stoi(argv[8]);

        StrideA  = std::stoi(argv[9]);
        StrideB0 = std::stoi(argv[10]);
        StrideB1 = std::stoi(argv[11]);
        StrideC  = std::stoi(argv[12]);

        BatchStrideA  = std::stoi(argv[13]);
        BatchStrideB0 = std::stoi(argv[14]);
        BatchStrideB1 = std::stoi(argv[15]);
        BatchStrideC  = std::stoi(argv[16]);
    }
    else
    {
@@ -282,35 +287,36 @@ int main(int argc, char* argv[])
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());

    auto a_element_op  = AElementOp{};
    auto b0_element_op = B0ElementOp{};
    auto b1_element_op = B1ElementOp{};
    auto c_element_op  = CElementOp{};

    // do GEMM
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

    auto argument =
        gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
                          static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                          static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                          static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
                          M,
                          N,
                          K,
                          O,
                          BatchCount,
                          StrideA,
                          StrideB0,
                          StrideB1,
                          StrideC,
                          BatchStrideA,
                          BatchStrideB0,
                          BatchStrideB1,
                          BatchStrideC,
                          a_element_op,
                          b0_element_op,
                          b1_element_op,
                          c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {

...
@@ -35,8 +35,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
    return transform_tensor_descriptor(
        TileDesc_K0_MN_K1{},
        make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                   make_unmerge_transform(
                       make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
        make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
@@ -694,7 +694,7 @@ struct BlockwiseGemmXdlops_v2
    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();
@@ -723,9 +723,8 @@ struct BlockwiseGemmXdlops_v2
    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());

    __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
                                               Tuple4 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
@@ -738,8 +737,7 @@ struct BlockwiseGemmXdlops_v2
                      "wrong!");
    }

    __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
        : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
    {
    }

...
@@ -38,22 +38,23 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -162,17 +163,17 @@ template <typename ALayout,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout,
                                                                         BLayout,
                                                                         B1Layout,
                                                                         CLayout,
                                                                         ADataType,
                                                                         BDataType,
                                                                         B1DataType,
                                                                         CDataType,
                                                                         AElementwiseOperation,
                                                                         BElementwiseOperation,
                                                                         B1ElementwiseOperation,
                                                                         CElementwiseOperation>
{
    using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
@@ -405,12 +406,12 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
        {
            const auto B1K0 = KRaw / B1K1;

            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b1_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b1_grid_desc_bk0_n_bk1;
        }
@@ -426,16 +427,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b1_grid_desc_bk0_n_bk1;
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
@@ -537,9 +537,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
    };

    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedGemmGemm_Xdl_CShuffle<
@@ -809,26 +809,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a, p_b, p_b1, p_c, MRaw,
                        NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
                        StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
                        BatchStrideB1, BatchStrideC, a_element_op, b_element_op, b1_element_op,
                        c_element_op};
    }

...
@@ -181,8 +181,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1  = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1  = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
@@ -207,7 +207,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(FloatAB),
                         c_block_size * sizeof(FloatCShuffle));
    }
@@ -234,7 +235,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            return false;
        }

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
             Gemm1N % Gemm1NPerBlock == 0))
        {
            return false;
        }
@@ -472,8 +474,10 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
        const auto a_block_reset_copy_step =
            make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
        const auto b_block_reset_copy_step =
            make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);

        // gridwise GEMM pipeline
        // Only supports LoopScheduler::Default

...
@@ -1154,11 +1154,11 @@ struct ThreadwiseTensorSliceTransfer_v4
        {
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
                    i * src_scalar_step_in_vector);

                // apply type convert
                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
            });
        }

        // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
@@ -1206,7 +1206,8 @@ template <typename SrcData,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic
{
    static constexpr index_t nDim = SliceLengths::Size();
@@ -1222,7 +1223,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
                      "wrong! Not divisible");
    }

    template <typename SrcSliceOriginIdx,
              typename DstSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,

@@ -1277,7 +1281,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
            });
        });
    }
};

} // namespace ck
@@ -739,13 +739,15 @@ struct XdlopsGemm
                      "base base_type must be double, float, half, bfloat16, and int8_t!");

        static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
            if constexpr(!TransposeC)
            {
                mfma_instr.template run<MPerXdlops, NPerXdlops>(
                    p_a_wave[k], p_b_wave[k], p_c_thread);
            }
            else
            {
                mfma_instr.template run<MPerXdlops, NPerXdlops>(
                    p_b_wave[k], p_a_wave[k], p_c_thread);
            }
        });
    }

...
@@ -69,7 +69,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
                    arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
                    arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));

                    v_acc +=
                        ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                }

                AccDataType v_c;

...
@@ -161,9 +161,10 @@ struct GeneratorTensor_Diagonal
    T operator()(Ts... Xs) const
    {
        std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
        size_t start_dim = dims.size() - NumEffectiveDim;
        bool pred = true;
        for(size_t i = start_dim + 1; i < dims.size(); i++)
        {
            pred &= (dims[start_dim] == dims[i]);
        }
        return pred ? value : T{0};

...