Unverified Commit dc70e3e1 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into gridwise_2d

parents 10947a54 8ee36118
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForBlockwiseWelford
{
GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration,
long_index_t reduce_length)
: numBlockTileIteration_{numBlockTileIteration}
{
count_in_last_tile_ = reduce_length % K_BlockTileSize;
};
__device__ index_t operator()(index_t thread_k_cluster_id) const
{
if(count_in_last_tile_ == 0)
return (KThreadSliceSize * numBlockTileIteration_);
else
{
index_t num_complete_slice = count_in_last_tile_ / KThreadSliceSize;
index_t count_in_last_slice = count_in_last_tile_ % KThreadSliceSize;
if(thread_k_cluster_id < num_complete_slice)
return (KThreadSliceSize * numBlockTileIteration_);
else if(thread_k_cluster_id == num_complete_slice)
return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice);
else
return (KThreadSliceSize * (numBlockTileIteration_ - 1));
};
};
index_t numBlockTileIteration_;
index_t count_in_last_tile_;
};
template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForMultiblockWelford
{
GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize,
index_t numBlockTileIteration,
long_index_t reduce_length)
: blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration}
{
last_block_reduce_length_ =
reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1);
numBlockTileIterationByLastBlock_ =
(last_block_reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
};
__device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const
{
if(last_block_reduce_length_ == K_BlockTileSize * numBlockTileIteration_ ||
block_local_id < blkGroupSize_ - 1)
return (KThreadSliceSize * numBlockTileIteration_);
index_t count_in_last_tile = last_block_reduce_length_ % K_BlockTileSize;
if(count_in_last_tile == 0)
return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
else
{
index_t num_complete_slice = count_in_last_tile / KThreadSliceSize;
if(thread_k_cluster_id < num_complete_slice)
return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
else if(thread_k_cluster_id == num_complete_slice)
return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1) +
count_in_last_tile);
else
return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1));
};
};
index_t blkGroupSize_;
index_t numBlockTileIteration_;
index_t last_block_reduce_length_;
index_t numBlockTileIterationByLastBlock_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -66,6 +66,7 @@ template <index_t BlockSize, ...@@ -66,6 +66,7 @@ template <index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t K0PerBlock, index_t K0PerBlock,
index_t K1Value,
index_t M1PerThreadM111, index_t M1PerThreadM111,
index_t N1PerThreadN111, index_t N1PerThreadN111,
index_t KPerThread, index_t KPerThread,
...@@ -96,7 +97,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -96,7 +97,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2); static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#pragma once #pragma once
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck { namespace ck {
......
...@@ -75,4 +75,63 @@ struct ThreadwiseWelford ...@@ -75,4 +75,63 @@ struct ThreadwiseWelford
int max_count_; int max_count_;
}; };
template <typename T,
typename SrcMeanVarCountThreadDesc_M_K,
typename DstMeanVarThreadDesc_M,
bool GetActualVariance = false>
struct ThreadwiseWelfordMerge
{
static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{};
static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
__device__ static void
Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
{
int count = count_a + count_b;
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count;
var_a += var_b + delta * delta * count_a * count_b_over_count;
count_a = count;
}
template <typename SrcMeanBufferType,
typename SrcVarBufferType,
typename SrcCountBufferType,
typename DstMeanBufferType,
typename DstVarBufferType,
typename DstCountBufferType>
__device__ static void Run(const SrcMeanBufferType& src_mean_buf,
const SrcVarBufferType& src_var_buf,
const SrcCountBufferType& src_count_buf,
DstMeanBufferType& dst_mean_buf,
DstVarBufferType& dst_var_buf,
DstCountBufferType& dst_count_buf)
{
static_for<0, src_length_m, 1>{}([&](auto iM) {
static_for<0, src_length_k, 1>{}([&](auto iK) {
constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
Merge(dst_mean_buf(iM),
dst_var_buf(iM),
dst_count_buf(iM),
src_mean_buf[Number<src_offset>{}],
src_var_buf[Number<src_offset>{}],
src_count_buf[Number<src_offset>{}]);
});
if constexpr(GetActualVariance)
{
dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
};
});
};
};
} // namespace ck } // namespace ck
...@@ -594,6 +594,7 @@ struct XdlopsGemm ...@@ -594,6 +594,7 @@ struct XdlopsGemm
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>; using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
...@@ -822,6 +823,16 @@ struct XdlopsGemm ...@@ -822,6 +823,16 @@ struct XdlopsGemm
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
} }
__device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{}; static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto mfma_instr = mfma.selected_mfma;
......
...@@ -6,6 +6,7 @@ function(add_instance_library INSTANCE_NAME) ...@@ -6,6 +6,7 @@ function(add_instance_library INSTANCE_NAME)
clang_tidy_check(${INSTANCE_NAME}) clang_tidy_check(${INSTANCE_NAME})
endfunction(add_instance_library INSTANCE_NAME) endfunction(add_instance_library INSTANCE_NAME)
file(GLOB dir_list LIST_DIRECTORIES true *) file(GLOB dir_list LIST_DIRECTORIES true *)
set(CK_DEVICE_INSTANCES) set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list}) FOREACH(subdir_path ${dir_list})
......
...@@ -34,27 +34,27 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< ...@@ -34,27 +34,27 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 2, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 2, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 2, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 1, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 16, 16, true, 7, 1> DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 2, 16, true, 7, 1>
// clang-format on // clang-format on
>; >;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment