"src/git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "b239346701bd8d9cbc2ba1a2f5053cb1e1d671b5"
Unverified Commit cec69bc3 authored by JD's avatar JD Committed by GitHub
Browse files

Add host API (#220)



* Add host API

* manually rebase on develop

* clean

* manually rebase on develop

* exclude tests from all target

* address review comments

* update client app name

* fix missing lib name

* clang-format update

* refactor

* refactor

* refactor

* refactor

* refactor

* fix test issue

* refactor

* refactor

* refactor

* upate cmake and readme
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 0f912e20
...@@ -149,9 +149,9 @@ int main(int argc, char* argv[]) ...@@ -149,9 +149,9 @@ int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce; using namespace ck::host_reduce;
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// Pool shape // Pool shape
ck::index_t N = 128; ck::index_t N = 128;
...@@ -171,13 +171,13 @@ int main(int argc, char* argv[]) ...@@ -171,13 +171,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 16) else if(argc == 16)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]); N = std::stoi(argv[4]);
C = std::stoi(argv[5]); C = std::stoi(argv[5]);
...@@ -196,7 +196,7 @@ int main(int argc, char* argv[]) ...@@ -196,7 +196,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
exit(0); exit(0);
...@@ -271,7 +271,7 @@ int main(int argc, char* argv[]) ...@@ -271,7 +271,7 @@ int main(int argc, char* argv[])
"not support this problem"); "not support this problem");
} }
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X; std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X;
......
...@@ -105,9 +105,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -105,9 +105,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -125,13 +125,13 @@ int main(int argc, char* argv[]) ...@@ -125,13 +125,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -145,7 +145,7 @@ int main(int argc, char* argv[]) ...@@ -145,7 +145,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -219,7 +219,7 @@ int main(int argc, char* argv[]) ...@@ -219,7 +219,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -60,21 +60,21 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -60,21 +60,21 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
if(argc == 4) if(argc == 4)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0); exit(0);
} }
...@@ -202,7 +202,7 @@ int main(int argc, char* argv[]) ...@@ -202,7 +202,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......
...@@ -58,9 +58,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: ...@@ -58,9 +58,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 1; bool do_verification = true;
int init_method = 1; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -79,13 +79,13 @@ int main(int argc, char* argv[]) ...@@ -79,13 +79,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -99,7 +99,7 @@ int main(int argc, char* argv[]) ...@@ -99,7 +99,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -192,30 +192,13 @@ int main(int argc, char* argv[]) ...@@ -192,30 +192,13 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
// warm up
invoker.Run(argument);
// timing
float total_time = 0;
for(int i = 0; i < nrepeat; ++i)
{
// init DO, D1 to 0 // init DO, D1 to 0
d0_device_buf.SetZero(); d0_device_buf.SetZero();
d1_device_buf.SetZero(); d1_device_buf.SetZero();
KernelTimer timer; // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test
timer.Start(); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
invoker.Run(argument);
timer.End();
total_time += timer.GetElapsedTime();
}
float ave_time = total_time / nrepeat;
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -87,7 +87,7 @@ void print_use_msg() ...@@ -87,7 +87,7 @@ void print_use_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
<< "arg3: run kernel # of times (>1)\n" << "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " N, K, C, \n"
...@@ -165,9 +165,9 @@ DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial) ...@@ -165,9 +165,9 @@ DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = true;
int init_method = 0; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
int num_dim_spatial = 2; int num_dim_spatial = 2;
ck::utils::conv::ConvParams params; ck::utils::conv::ConvParams params;
...@@ -177,13 +177,13 @@ int main(int argc, char* argv[]) ...@@ -177,13 +177,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc > 4) else if(argc > 4)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); num_dim_spatial = std::stoi(argv[4]);
// check args number // check args number
int conv_args = 3 + num_dim_spatial * 6; int conv_args = 3 + num_dim_spatial * 6;
...@@ -284,7 +284,7 @@ int main(int argc, char* argv[]) ...@@ -284,7 +284,7 @@ int main(int argc, char* argv[])
"not support this Conv problem"); "not support this Conv problem");
} }
float ave_time = invoker->Run(argument.get(), nrepeat); float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = ck::utils::conv::get_flops( std::size_t flop = ck::utils::conv::get_flops(
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
......
...@@ -57,9 +57,9 @@ using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: ...@@ -57,9 +57,9 @@ using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 1; bool do_verification = true;
int init_method = 1; int init_method = 1;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 3840; ck::index_t M = 3840;
...@@ -80,13 +80,13 @@ int main(int argc, char* argv[]) ...@@ -80,13 +80,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 11) else if(argc == 11)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -102,7 +102,7 @@ int main(int argc, char* argv[]) ...@@ -102,7 +102,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n"); printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n");
exit(0); exit(0);
} }
...@@ -204,30 +204,13 @@ int main(int argc, char* argv[]) ...@@ -204,30 +204,13 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
// warm up
invoker.Run(argument);
// timing
float total_time = 0;
for(int i = 0; i < nrepeat; ++i)
{
// init DO, D1 to 0 // init DO, D1 to 0
d0_device_buf.SetZero(); d0_device_buf.SetZero();
d1_device_buf.SetZero(); d1_device_buf.SetZero();
KernelTimer timer; // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test
timer.Start(); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
invoker.Run(argument);
timer.End();
total_time += timer.GetElapsedTime();
}
float ave_time = total_time / nrepeat;
std::size_t flop = std::size_t(2) * BatchCount * M * N * K; std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K +
......
#pragma once
// "_PACKAGE_" to avoid name contentions: the macros like
// HIP_VERSION_MAJOR are defined in HIP_VERSION.h.
// clang-format off
#define CK_HIP_PACKAGE_VERSION_MAJOR @CK_HIP_VERSION_MAJOR@
#define CK_HIP_PACKAGE_VERSION_MINOR @CK_HIP_VERSION_MINOR@
#define CK_HIP_PACKAGE_VERSION_PATCH @CK_HIP_VERSION_PATCH@
// clang-format on
#ifndef CK_HIP_PACKAGE_VERSION_MAJOR
#define CK_HIP_PACKAGE_VERSION_MAJOR 0
#endif
#ifndef CK_HIP_PACKAGE_VERSION_MINOR
#define CK_HIP_PACKAGE_VERSION_MINOR 0
#endif
#ifndef CK_HIP_PACKAGE_VERSION_PATCH
#define CK_HIP_PACKAGE_VERSION_PATCH 0
#endif
// 3 decimal digits for major and minor, 6 digits for patch number.
// Max number is 999,999,999999 == 0xE8,D4A5,0FFF that fits into 64-bit math.
#if CK_HIP_PACKAGE_VERSION_MAJOR > 999 || CK_HIP_PACKAGE_VERSION_MAJOR > 999 || \
CK_HIP_PACKAGE_VERSION_PATCH > 999999
#error "Too big HIP version number(s)"
#endif
#define CK_HIP_PACKAGE_VERSION_FLAT \
((CK_HIP_PACKAGE_VERSION_MAJOR * 1000ULL + CK_HIP_PACKAGE_VERSION_MINOR) * 1000000 + \
CK_HIP_PACKAGE_VERSION_PATCH)
#pragma once
#cmakedefine01 CK_TIME_KERNEL
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
struct StreamConfig
{
hipStream_t stream_id_ = nullptr;
bool time_kernel_ = false;
};
#ifndef DEVICE_BASE_HPP #pragma once
#define DEVICE_BASE_HPP
#include <string> #include <string>
#include "stream_config.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -22,7 +23,10 @@ struct BaseInvoker ...@@ -22,7 +23,10 @@ struct BaseInvoker
BaseInvoker(const BaseInvoker&) = default; BaseInvoker(const BaseInvoker&) = default;
BaseInvoker& operator=(const BaseInvoker&) = default; BaseInvoker& operator=(const BaseInvoker&) = default;
virtual float Run(const BaseArgument*, int = 1) = 0; virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
{
return float{0};
}
virtual ~BaseInvoker() {} virtual ~BaseInvoker() {}
}; };
...@@ -33,8 +37,8 @@ struct BaseOperator ...@@ -33,8 +37,8 @@ struct BaseOperator
BaseOperator(const BaseOperator&) = default; BaseOperator(const BaseOperator&) = default;
BaseOperator& operator=(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default;
virtual bool IsSupportedArgument(const BaseArgument*) = 0; virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const = 0; virtual std::string GetTypeString() const { return ""; }
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
...@@ -42,4 +46,3 @@ struct BaseOperator ...@@ -42,4 +46,3 @@ struct BaseOperator
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -693,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -693,7 +693,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int /* nrepeat */ = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -729,6 +729,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -729,6 +729,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
...@@ -748,7 +749,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -748,7 +749,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
launch_kernel(kernel, elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -788,7 +791,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -788,7 +791,9 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
launch_kernel(kernel, elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -810,13 +815,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi ...@@ -810,13 +815,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
return 0; return elapsed_time;
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -428,7 +428,7 @@ struct DeviceBatchedGemmXdl ...@@ -428,7 +428,7 @@ struct DeviceBatchedGemmXdl
{ {
using Argument = DeviceBatchedGemmXdl::Argument; using Argument = DeviceBatchedGemmXdl::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
...@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl ...@@ -477,8 +477,8 @@ struct DeviceBatchedGemmXdl
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl ...@@ -511,8 +511,8 @@ struct DeviceBatchedGemmXdl
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl ...@@ -534,9 +534,10 @@ struct DeviceBatchedGemmXdl
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -415,9 +415,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
ShowInfo(arg); ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
...@@ -437,35 +438,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -437,35 +438,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float ave_time = 0; float ave_time = 0;
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
if(nrepeat > 0)
{
ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
if(kbatch > 1 || nrepeat <= 0)
{
hipGetErrorString(hipMemset( hipGetErrorString(hipMemset(
arg.p_c_grid_, arg.p_c_grid_,
0, 0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType))); sizeof(CDataType)));
launch_kernel(kernel, launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -479,7 +459,6 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -479,7 +459,6 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
}
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
...@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -560,9 +539,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -531,7 +531,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
...@@ -602,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -602,8 +602,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
true>; true>;
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -635,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -635,8 +635,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
false>; false>;
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -655,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -655,9 +655,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -642,7 +642,7 @@ struct ...@@ -642,7 +642,7 @@ struct
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -727,8 +727,8 @@ struct ...@@ -727,8 +727,8 @@ struct
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -771,8 +771,8 @@ struct ...@@ -771,8 +771,8 @@ struct
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -795,9 +795,10 @@ struct ...@@ -795,9 +795,10 @@ struct
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -605,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -605,7 +605,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -684,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -684,8 +684,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -723,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -723,8 +723,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -745,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -745,9 +745,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -568,7 +568,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -663,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -663,8 +663,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -697,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -697,8 +697,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config,
kernel, kernel,
nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -717,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -717,9 +717,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -450,7 +450,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0 #if 0
{ {
...@@ -498,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -498,8 +498,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -529,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -529,8 +529,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -549,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -549,9 +549,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
return ave_time; return ave_time;
} }
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -92,7 +92,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W ...@@ -92,7 +92,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto naive_conv3d_fwd = const auto naive_conv3d_fwd =
ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType, ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
...@@ -103,8 +103,8 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W ...@@ -103,8 +103,8 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation>; OutElementwiseOperation>;
float ave_time = launch_and_time_kernel(naive_conv3d_fwd, float ave_time = launch_and_time_kernel(stream_config,
nrepeat, naive_conv3d_fwd,
dim3(256), dim3(256),
dim3(256), dim3(256),
0, 0,
...@@ -137,9 +137,10 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W ...@@ -137,9 +137,10 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -438,7 +438,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -438,7 +438,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
{ {
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
...@@ -487,8 +487,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -487,8 +487,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -522,8 +522,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -522,8 +522,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -547,9 +547,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -547,9 +547,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment