"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "0dbbf9c362b1a1de7d351a8a3aa573c7c0cc62b8"
Unverified Commit 29deceb6 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #18 from ROCmSoftwarePlatform/merge-from-public

Merge from public
parents 91c1d147 c997bbf6
@PACKAGE_INIT@
set(_composable_kernel_supported_components device_operations utility)
set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility)
foreach(_comp ${composable_kernel_FIND_COMPONENTS})
if(NOT _comp IN_LIST _composable_kernel_supported_components)
......
......@@ -694,8 +694,8 @@ pipeline {
description: "Use the CK build to verify hipTensor build and tests (default: ON)")
string(
name: 'hipTensor_branch',
defaultValue: 'mainline',
description: 'Specify which branch of hipTensor to use (default: mainline)')
defaultValue: 'develop',
description: 'Specify which branch of hipTensor to use (default: develop)')
booleanParam(
name: "USE_SCCACHE",
defaultValue: true,
......@@ -759,7 +759,7 @@ pipeline {
}
agent{ label rocmnode("gfx908 || gfx90a") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
steps{
......
add_executable(client_gemm gemm.cpp)
target_link_libraries(client_gemm PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_custom_target(client_gemm_fastgelu_examples)
add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp)
target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_fastgelu gemm_fastgelu.cpp)
target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
client_gemm_fastgelu)
......@@ -15,13 +15,13 @@ add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu clie
add_custom_target(client_gemm_fastgelu_generic_examples)
add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp)
target_link_libraries(client_gemm_add_add_fastgelu_generic PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations)
add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp)
target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp)
target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations)
add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic)
add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp)
target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp)
target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp)
target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_operations)
target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
add_executable(client_layernorm2d layernorm2d.cpp)
target_link_libraries(client_layernorm2d PRIVATE composable_kernel::device_operations)
add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp)
target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations)
add_executable(client_layernorm4d_fwd layernorm4d_fwd.cpp)
target_link_libraries(client_layernorm4d_fwd PRIVATE composable_kernel::device_other_operations)
......@@ -7,10 +7,10 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
......@@ -57,14 +57,14 @@ int main(int argc, char* argv[])
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M);
#endif
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
PassThrough,
Rank,
NumReduceDim>;
using DeviceOp = ck::tensor_operation::device::DeviceNormalizationFwd<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
PassThrough,
Rank,
NumReduceDim>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
ck::index_t N = 256;
ck::index_t H = 16;
ck::index_t W = 16;
ck::index_t C = 8;
std::vector<ck::index_t> strideXY = {H * W * C, W * C, C, 1};
std::vector<ck::index_t> strideGammaBeta = {0, W * C, C, 1};
std::vector<ck::index_t> strideSaveMeanInvStd = {1};
SimpleDeviceMem x_device_buf(sizeof(XDataType) * N * H * W * C);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * H * W * C);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * H * W * C);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * N * H * W * C);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N);
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N);
#endif
using DeviceOp = ck::tensor_operation::device::DeviceNormalizationFwd<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
PassThrough,
Rank,
NumReduceDim>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer({N, H, W, C}, // lengths
strideXY, // xStrides
strideGammaBeta, // gammaStrides
strideGammaBeta, // betaStrides
strideXY, // yStrides
strideSaveMeanInvStd, // save_mean Strides
strideSaveMeanInvStd, // save_inv_std Strides
{1, 2, 3}, // reduceDims
1e-4,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_byte =
sizeof(XDataType) * N * H * W * C + sizeof(GammaDataType) * H * W * C +
sizeof(BetaDataType) * H * W * C + sizeof(YDataType) * N * H * W * C;
#ifdef SAVE_MEAN_INV_STD
num_byte += sizeof(SaveMeanInvStdDataType) * N * 2;
#endif
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
// run the best intance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr =
op_ptr->MakeArgumentPointer({N, H, W, C}, // lengths
strideXY, // xStrides
strideGammaBeta, // gammaStrides
strideGammaBeta, // betaStrides
strideXY, // yStrides
strideSaveMeanInvStd, // save_mean Strides
strideSaveMeanInvStd, // save_inv_std Strides
{1, 2, 3}, // reduceDims
1e-4,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
add_executable(client_softmax4d softmax4d.cpp)
target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_operations)
target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations)
add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_operations)
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_operations)
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
......@@ -100,18 +100,18 @@ int main()
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C);
SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
PassThrough>;
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
PassThrough>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
......
......@@ -71,18 +71,18 @@ int main()
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
PassThrough>;
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
PassThrough>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
......
add_executable(client_fused_attention fused_attention.cpp)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp)
target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
add_executable(client_gemm_quantization gemm_quantization.cpp)
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations)
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
endif()
......@@ -80,7 +80,7 @@ int main(int argc, char* argv[])
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,
WeiLayout,
......
......@@ -78,18 +78,18 @@ int main(int argc, char* argv[])
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
using DeviceOp =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<BiasLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<BiasDataType>,
OutDataType,
PassThrough,
PassThrough,
OutElementOp>;
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<BiasLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<BiasDataType>,
OutDataType,
PassThrough,
PassThrough,
OutElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
......
......@@ -83,7 +83,7 @@ int main(int argc, char* argv[])
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,
WeiLayout,
......
......@@ -79,18 +79,18 @@ int main(int argc, char* argv[])
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
using DeviceOp =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<BiasLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<BiasDataType>,
OutDataType,
PassThrough,
PassThrough,
OutElementOp>;
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<BiasLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<BiasDataType>,
OutDataType,
PassThrough,
PassThrough,
OutElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
......
......@@ -76,19 +76,19 @@ int main(int argc, char* argv[])
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
using DeviceOp =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<RequantScaleLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<RequantScaleDataType>,
OutDataType,
PassThrough,
PassThrough,
OutElementOp>;
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<RequantScaleLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<RequantScaleDataType>,
OutDataType,
PassThrough,
PassThrough,
OutElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment