"docs/source/en/api/schedulers/repaint.md" did not exist on "21e61eb3a9d16a46245bd284fea3aa19e66772f5"
Commit 25e751f0 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan Committed by hsadasiv
Browse files

updating files for splitK GEMV; disclaimer: for accuracy testing in example/...

updating files for splitK GEMV; disclaimer: for accuracy testing in example/ remember to run the kernel just once because of memset happening before atomicadd
parent 57d0ea67
add_custom_target(example_splitK_gemv)
add_example_executable(example_splitK_gemv_fp16 splitK_gemv_fp16.cpp)
add_dependencies(example_splitK_gemv
example_splitK_gemv_fp16)
\ No newline at end of file
# Instructions for ```example_gemm_xdl```
## Run ```example_gemm_xdl```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
./bin/example_gemm_xdl 0 1 5
```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
arg.c_grid_desc_m_n_{ 3840, 4096}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 5 times...
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideC = std::stoi(argv[9]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
struct ProblemSize_gemv final
{
// ck::index_t M = 1; // 3840;
// ck::index_t N = 128;
// ck::index_t K = 128;
// ck::index_t stride_A = K;
// ck::index_t stride_B = K;
// ck::index_t stride_C = N;
// ck::index_t k_batch = 1;
ck::index_t M = 1; // 3840;
ck::index_t N = 1104;
ck::index_t K = 4608;
ck::index_t stride_A = K;
ck::index_t stride_B = K;
ck::index_t stride_C = N;
ck::index_t k_batch = 1;
};
bool run_gemv(const ProblemSize_gemv& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size; // //
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemv = DeviceGemvInstance{};
auto invoker = gemv.MakeInvoker();
auto argument = gemv.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
k_batch); // //
// //
if(!gemv.IsSupportedArgument(argument))
{
std::cerr << gemv.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemv.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemv = ReferenceGemmInstance{};
auto ref_invoker = ref_gemv.MakeInvoker();
auto ref_argument = ref_gemv.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
return true;
}
bool run_gemv_example(int argc, char* argv[])
{
ProblemSize_gemv problem_size;
// problem_size.M = 1;
ExecutionConfig config;
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: KBatch\n");
printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
return run_gemv(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_splitK_gemv.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 8 //2,4,8
#define K0 4 //1,2,3,4...32 --> 423ms vs 282ms (k0=4)
#define N1 2 //2,4,8
#define B 64 //64
// clang-format off
using DeviceGemvInstance = ck::tensor_operation::device::deviceGemvDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 128, 8, 8, 1, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1, 1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, 1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1, 1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, K1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_splitK_gemv_example.inc"
int main(int argc, char* argv[]) { return !run_gemv_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemv : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch=1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemv.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_splitK_gemv.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct deviceGemvDl : public DeviceGemv<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// GridwiseGemv
using GridwiseGemv =
GridwiseGemvDl_km_kn_mn<BlockSize,
ADataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
GemmSpec,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
MPerThread,
NPerThread,
KPerThread,
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using DefaultBlock2CTileMap = typename GridwiseGemv::DefaultBlock2CTileMap;
using Argument = typename GridwiseGemv::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size = GridwiseGemv::CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto b2c_map = DefaultBlock2CTileMap{};
const auto K0 = karg.K0;
const bool has_main_k_block_loop = GridwiseGemv::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemv::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
if(karg.k_batch > 1)
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
true,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg, b2c_map);
}
else
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
true,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg, b2c_map);
}
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
true,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg, b2c_map);
}
else
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
true,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg, b2c_map);
}
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg, b2c_map);
}
else
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg, b2c_map);
}
}
else
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg, b2c_map);
}
else
{
const auto kernel = kernel_gemv_dl_v1r3<GridwiseGemv,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg, b2c_map);
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
// //
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
return GridwiseGemv::CheckValidity(arg);
}
else
{
return false;
}
}
// //
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
index_t KBatch) // //
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
GridwiseGemv::CalculateMPadded(M),
GridwiseGemv::CalculateNPadded(N),
GridwiseGemv::CalculateKPadded(K, KBatch),
GridwiseGemv::CalculateK0(K, KBatch),
KBatch}; // //
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ck::index_t KBatch = 1) override // //
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
GridwiseGemv::CalculateMPadded(M),
GridwiseGemv::CalculateNPadded(N),
GridwiseGemv::CalculateKPadded(K, KBatch),
GridwiseGemv::CalculateK0(K, KBatch),
KBatch); // //
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "deviceGemvDl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerThread << ", "
<< NPerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -655,6 +655,16 @@ struct BlockToCTileMap_3DGrid_KSplit ...@@ -655,6 +655,16 @@ struct BlockToCTileMap_3DGrid_KSplit
return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x); return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
} }
//HS: Map 1D block-id to 3D tuple (M,N,K)
__host__ __device__ inline constexpr auto convert_1D_block_idx_to_3D_tuple(
const index_t& block_1d_id, const index_t& N, const index_t& k_batch) const
{
const auto Ndim= math::integer_divide_ceil(N, NPerBlock);
return make_tuple(((block_1d_id) / (k_batch * Ndim)),
(((block_1d_id) / k_batch) % Ndim),
(block_1d_id) % k_batch); // returns 3D tuple as (Mid,Nid,Kid)
}
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const const CTileDim& /* c_tile_dim */) const
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment