Unverified Commit 500fa995 authored by Chao Liu, committed by GitHub

Clean up conv example, Instances, profiler and test (#324)

* convnd_fwd fp16 example

* update example

* update example

* update instance

* updating reference conv

* update reference conv

* update conv fwd profiler

* update conv 1d and 3d instance

* update include path

* clean

* update profiler for conv bwd data and weight

* update conv bwd weight

* clean

* update conv example

* update profiler for conv bwd weight

* update ckprofiler for conv bwd data

* fix reference conv bwd data bug; update conv bwd data test

* update examples

* fix initialization issue

* update test for conv fwd

* clean

* clean

* remove test case too sensitive to error threshold

* fix test

* clean

* fix build

* adding conv multiple d

* adding conv multiple D

* add matrix padder

* add gemm padding to convnd

* adding group conv

* update gemm multi-d

* refactor

* refactor

* refactor

* clean

* clean

* refactor

* refactor

* reorg

* add ds

* add bias

* clean

* add G

* adding group

* adding group

* adding group

* update Tensor

* clean

* update example

* update DeviceGemmMultipleD_Xdl_CShuffle

* update conv bwd-data and bwd-weight

* update contraction example

* update gemm and batch gemm with e permute

* fix example build

* instance for grouped conv1d

* update example

* adding group conv instance

* update gemm bilinear instance

* update gemm+add+add+fastgelu instance

* update profiler

* update profiler

* update test

* update test and client example

* clean

* add grouped conv into profiler

* update profiler

* clean

* add test grouped conv, update all conv test to gtest

* update test
parent 85978e02
@@ -19,49 +19,53 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                     Row,
+                                                    Row_Tuple,
                                                     Row,
                                                     F16,
                                                     F16,
-                                                    F16_TUPLE,
+                                                    F16_Tuple,
                                                     F16,
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
-void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                     Col,
+                                                    Row_Tuple,
                                                     Row,
                                                     F16,
                                                     F16,
-                                                    F16_TUPLE,
+                                                    F16_Tuple,
                                                     F16,
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
-void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                     Row,
+                                                    Row_Tuple,
                                                     Row,
                                                     F16,
                                                     F16,
-                                                    F16_TUPLE,
+                                                    F16_Tuple,
                                                     F16,
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
-void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                     Col,
+                                                    Row_Tuple,
                                                     Row,
                                                     F16,
                                                     F16,
-                                                    F16_TUPLE,
+                                                    F16_Tuple,
                                                     F16,
                                                     PassThrough,
                                                     PassThrough,
@@ -70,7 +74,8 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
 // GEMM + Bilinear
 template <typename ALayout,
           typename BLayout,
-          typename DELayout,
+          typename DLayout,
+          typename ELayout,
           typename ADataType,
           typename BDataType,
           typename DDataType,
@@ -78,7 +83,8 @@ template <typename ALayout,
 struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
     ALayout,
     BLayout,
-    DELayout,
+    ck::Tuple<DLayout>,
+    ELayout,
     ADataType,
     BDataType,
     ck::Tuple<DDataType>,
@@ -89,7 +95,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 {
     using DeviceOp = DeviceGemmMultipleD<ALayout,
                                          BLayout,
-                                         DELayout,
+                                         ck::Tuple<DLayout>,
+                                         ELayout,
                                          ADataType,
                                          BDataType,
                                          ck::Tuple<DDataType>,
@@ -106,24 +113,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                      is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-                         is_same_v<DELayout, Row>)
+                         is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
             {
-                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
+                    op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
-                              is_same_v<DELayout, Row>)
+                              is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
             {
-                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
+                    op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-                              is_same_v<DELayout, Row>)
+                              is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
             {
-                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
+                    op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-                              is_same_v<DELayout, Row>)
+                              is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
             {
-                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+                add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
+                    op_ptrs);
             }
         }
...
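A short usage sketch of the renamed factory follows. The `Row`, `F16`, and element-wise aliases are assumptions mirroring the usual CK shorthand (ck::tensor_layout::gemm::RowMajor, ck::half_t), not declarations from the hunk above.

using Row         = ck::tensor_layout::gemm::RowMajor;
using F16         = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear    = ck::tensor_operation::element_wise::Bilinear;

// D and E layouts are now independent template parameters; the factory
// specialization wraps DLayout into ck::Tuple<DLayout> internally.
using GemmBilinearFactory =
    ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        ck::tensor_operation::device::DeviceGemmMultipleD<Row,            // ALayout
                                                          Row,            // BLayout
                                                          ck::Tuple<Row>, // Ds layouts
                                                          Row,            // ELayout
                                                          F16,
                                                          F16,
                                                          ck::Tuple<F16>, // Ds data types
                                                          F16,
                                                          PassThrough,
                                                          PassThrough,
                                                          Bilinear>>;

// Collects every f16 mk_kn_mn_mn bilinear instance registered above.
auto gemm_bilinear_ptrs = GemmBilinearFactory::GetInstances();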
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv1d forward, GNWC/GKXC/GNWK
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
GKXC,
Empty_Tuple,
GNWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv2d forward, NHWGC/KYXGC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
KYXGC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
Empty_Tuple,
OutLayout,
InDataType,
WeiDataType,
Empty_Tuple,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
Empty_Tuple,
OutLayout,
InDataType,
WeiDataType,
Empty_Tuple,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, KYXGC> && is_same_v<OutLayout, NHWGK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
// no instance
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
// no instance
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
// no instance
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
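For orientation, a minimal consumer sketch follows. It assumes the GNHWC/GKYXC/GNHWK aliases above refer to the ck::tensor_layout::convolution types, that Empty_Tuple is ck::Tuple<>, and that GetTypeString() is available from the common operator base class.

#include <iostream>

namespace conv_layout = ck::tensor_layout::convolution;
using PassThroughOp   = ck::tensor_operation::element_wise::PassThrough;

using GroupedConv2dFwdFactory =
    ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<2,
                                                                    conv_layout::GNHWC,
                                                                    conv_layout::GKYXC,
                                                                    ck::Tuple<>,
                                                                    conv_layout::GNHWK,
                                                                    ck::half_t,
                                                                    ck::half_t,
                                                                    ck::Tuple<>,
                                                                    ck::half_t,
                                                                    PassThroughOp,
                                                                    PassThroughOp,
                                                                    PassThroughOp>>;

inline void list_grouped_conv2d_fwd_f16_instances()
{
    // GetInstances() appends every pre-built instance whose layouts and
    // data types match the factory's template arguments.
    auto op_ptrs = GroupedConv2dFwdFactory::GetInstances();
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';
}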
@@ -16,15 +16,14 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-using DsType = Tuple<>;
 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                   Row,
+                                                  Empty_Tuple,
                                                   Row,
                                                   F16,
                                                   F16,
-                                                  DsType,
+                                                  Empty_Tuple,
                                                   F16,
                                                   PassThrough,
                                                   PassThrough,
@@ -33,10 +32,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                   Col,
+                                                  Empty_Tuple,
                                                   Row,
                                                   F16,
                                                   F16,
-                                                  DsType,
+                                                  Empty_Tuple,
                                                   F16,
                                                   PassThrough,
                                                   PassThrough,
@@ -45,10 +45,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
 void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
                                                   Row,
+                                                  Empty_Tuple,
                                                   Row,
                                                   F16,
                                                   F16,
-                                                  DsType,
+                                                  Empty_Tuple,
                                                   F16,
                                                   PassThrough,
                                                   PassThrough,
@@ -57,10 +58,11 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
 void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
                                                   Col,
+                                                  Empty_Tuple,
                                                   Row,
                                                   F16,
                                                   F16,
-                                                  DsType,
+                                                  Empty_Tuple,
                                                   F16,
                                                   PassThrough,
                                                   PassThrough,
@@ -68,18 +70,18 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
 template <typename ALayout,
           typename BLayout,
-          typename CLayout,
+          typename ELayout,
           typename ADataType,
           typename BDataType,
-          typename DsDataType,
           typename EDataType>
 struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
     ALayout,
     BLayout,
-    CLayout,
+    Empty_Tuple,
+    ELayout,
     ADataType,
     BDataType,
-    DsDataType,
+    Empty_Tuple,
     EDataType,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
@@ -87,10 +89,11 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
 {
     using DeviceOp = DeviceGroupedGemm<ALayout,
                                        BLayout,
-                                       CLayout,
+                                       Empty_Tuple,
+                                       ELayout,
                                        ADataType,
                                        BDataType,
-                                       DsDataType,
+                                       Empty_Tuple,
                                        EDataType,
                                        ck::tensor_operation::element_wise::PassThrough,
                                        ck::tensor_operation::element_wise::PassThrough,
@@ -104,22 +107,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<EDataType, half_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-                         is_same_v<CLayout, Row>)
+                         is_same_v<ELayout, Row>)
             {
                 add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
-                              is_same_v<CLayout, Row>)
+                              is_same_v<ELayout, Row>)
             {
                 add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-                              is_same_v<CLayout, Row>)
+                              is_same_v<ELayout, Row>)
             {
                 add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-                              is_same_v<CLayout, Row>)
+                              is_same_v<ELayout, Row>)
             {
                 add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
             }
...
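The same pattern applies to the grouped-GEMM factory, which is now fixed to an empty Ds tuple; a hedged sketch, reusing the aliases from the earlier sketch plus Col for ColumnMajor:

using Col = ck::tensor_layout::gemm::ColumnMajor;

using GroupedGemmFactory =
    ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        ck::tensor_operation::device::DeviceGroupedGemm<Row,
                                                        Col,
                                                        ck::Tuple<>, // no Ds layouts
                                                        Row,
                                                        ck::half_t,
                                                        ck::half_t,
                                                        ck::Tuple<>, // no Ds data types
                                                        ck::half_t,
                                                        PassThrough,
                                                        PassThrough,
                                                        PassThrough>>;

auto grouped_gemm_ptrs = GroupedGemmFactory::GetInstances(); // f16 mk_nk_mn instances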
@@ -13,7 +13,9 @@
 #include <type_traits>
 #include <vector>
+#include "ck/ck.hpp"
 #include "ck/utility/data_type.hpp"
+#include "ck/host_utility/io.hpp"
 namespace ck {
 namespace utils {
@@ -194,10 +196,3 @@ check_err(const std::vector<T>& out,
 } // namespace utils
 } // namespace ck
-template <typename T>
-std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
-{
-    std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
-    return os;
-}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <functional>
#include <iterator>
#include <numeric>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/op_instance_engine.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
element_wise::PassThrough,
element_wise::PassThrough>;
namespace instance {
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace instance
namespace instance {
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace instance
namespace instance {
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace utils {
namespace conv {
using DeviceConvFwdNoOpPtr =
ck::tensor_operation::device::DeviceConvFwdPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
/**
* @brief Calculate number of FLOPs for Convolution
*
* @param[in] N Batch size.
* @param[in] C Number of input channels.
* @param[in] K Number of output channels.
* @param[in] filter_spatial_lengths Filter spatial dimensions lengths.
* @param[in] output_spatial_lengths Convolution output spatial dimensions
* lengths.
*
* @return The number of flops.
*/
std::size_t get_flops(ck::index_t N,
ck::index_t C,
ck::index_t K,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths);
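// Reconstruction sketch: get_flops is defined in the .cpp, and the 2-ops-per-MAC
// convention below is an assumption; get_flops_sketch is a hypothetical name
// used only for illustration.
inline std::size_t get_flops_sketch(ck::index_t N,
                                    ck::index_t C,
                                    ck::index_t K,
                                    const std::vector<ck::index_t>& filter_spatial_lengths,
                                    const std::vector<ck::index_t>& output_spatial_lengths)
{
    const std::size_t filter_prod = std::accumulate(std::begin(filter_spatial_lengths),
                                                    std::end(filter_spatial_lengths),
                                                    static_cast<std::size_t>(1),
                                                    std::multiplies<std::size_t>());
    const std::size_t output_prod = std::accumulate(std::begin(output_spatial_lengths),
                                                    std::end(output_spatial_lengths),
                                                    static_cast<std::size_t>(1),
                                                    std::multiplies<std::size_t>());

    // flop = 2 * N * C * K * <output spatial product> * <filter spatial product>
    return static_cast<std::size_t>(2) * N * C * K * output_prod * filter_prod;
}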
/**
 * @brief Calculate the number of bytes read/written by the convolution algorithm.
*
* @param[in] N Batch size.
* @param[in] C Number of input channels.
* @param[in] K Number of output channels.
* @param[in] input_spatial_lengths Input spatial dimensions lengths.
* @param[in] filter_spatial_lengths Filter spatial dimensions lengths.
* @param[in] output_spatial_lengths Output spatial dimensions lengths
*
* @tparam InDataType Input tensor data type.
* @tparam WeiDataType Weights tensor data type.
* @tparam OutDataType Output tensor data type.
*
* @return The number of used bytes.
*/
template <typename InDataType = float,
typename WeiDataType = InDataType,
typename OutDataType = InDataType>
std::size_t get_btype(ck::index_t N,
ck::index_t C,
ck::index_t K,
const std::vector<ck::index_t>& input_spatial_lengths,
const std::vector<ck::index_t>& filter_spatial_lengths,
const std::vector<ck::index_t>& output_spatial_lengths)
{
// sizeof(InDataType) * (N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (N * K * <output spatial lengths product>);
return sizeof(InDataType) * (N * C *
std::accumulate(std::begin(input_spatial_lengths),
std::end(input_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(WeiDataType) * (K * C *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(OutDataType) * (N * K *
std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
struct ConvParams
{
ConvParams();
ConvParams(ck::index_t n_dim,
ck::index_t n_batch,
ck::index_t n_out_channels,
ck::index_t n_in_channels,
const std::vector<ck::index_t>& filters_len,
const std::vector<ck::index_t>& input_len,
const std::vector<ck::index_t>& strides,
const std::vector<ck::index_t>& dilations,
const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads);
ck::index_t num_dim_spatial_;
ck::index_t N_;
ck::index_t K_;
ck::index_t C_;
std::vector<ck::index_t> filter_spatial_lengths_;
std::vector<ck::index_t> input_spatial_lengths_;
std::vector<ck::index_t> conv_filter_strides_;
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
std::vector<ck::index_t> GetOutputSpatialLengths() const;
};
ConvParams parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]);
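// Example (illustrative values): N=4, K=64, C=32, a 3x3 filter on a 71x71
// input, stride 2, dilation 1, symmetric padding 1:
//
//   ConvParams params(2, 4, 64, 32, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1});
//   auto out_lens = params.GetOutputSpatialLengths(); // {36, 36}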
/**
* @brief Gets the host tensor descriptor.
*
* @param[in] dims The tensor dimensions lengths. Always in NCHW format.
* @param[in] layout The tensor data layout.
*
* @tparam TensorLayout Layout type.
*
* @return The host tensor descriptor object.
*/
template <typename TensorLayout>
HostTensorDescriptor get_host_tensor_descriptor(const std::vector<std::size_t>& dims,
const TensorLayout& layout)
{
std::size_t C = dims[1];
// 1D
if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKW>::value)
{
return HostTensorDescriptor(dims, std::vector<std::size_t>{C * dims[2], dims[2], 1});
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NWK>::value)
{
return HostTensorDescriptor(dims, std::vector<std::size_t>{C * dims[2], 1, C});
}
// 2D
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCHW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCYX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKHW>::value)
{
return HostTensorDescriptor(
dims, std::vector<std::size_t>{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1});
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KYXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWK>::value)
{
return HostTensorDescriptor(
dims, std::vector<std::size_t>{C * dims[2] * dims[3], 1, dims[3] * C, C});
}
// 3D
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCZYX>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKDHW>::value)
{
return HostTensorDescriptor(dims,
std::vector<std::size_t>{C * dims[2] * dims[3] * dims[4],
dims[2] * dims[3] * dims[4],
dims[3] * dims[4],
dims[4],
1});
}
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KZYXC>::value ||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWK>::value)
{
return HostTensorDescriptor(
dims,
std::vector<std::size_t>{
C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C});
}
std::stringstream err_msg;
err_msg << "Unsupported data layout provided: " << layout << "!";
throw std::runtime_error(err_msg.str());
}
HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2);
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2);
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2);
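// Example: an NHWC input with N=4, C=32, H=W=71. Dimension lengths are always
// passed in NCHW order; the layout argument only selects the strides:
//
//   auto desc = get_host_tensor_descriptor({4, 32, 71, 71},
//                                          ck::tensor_layout::convolution::NHWC{});
//   // lengths {4, 32, 71, 71}, strides {32 * 71 * 71, 1, 71 * 32, 32}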
template <ck::index_t NDim,
typename InDataType = float,
typename WeiDataType = float,
typename OutDataType = float>
void run_reference_convolution_forward(const ConvParams& params,
const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weights,
Tensor<OutDataType>& output)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
OutDataType,
PassThrough,
PassThrough,
PassThrough,
NDim>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(input,
weights,
output,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
}
template <typename InDataType, typename WeiDataType, typename OutDataType>
struct ConvolutionFwdInstances;
template <>
struct ConvolutionFwdInstances<float, float, float>
{
template <int NumDimSpatial,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
static std::vector<DeviceConvFwdNoOpPtr> Get()
{
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
}
return conv_ptrs;
}
};
template <>
struct ConvolutionFwdInstances<half_t, half_t, half_t>
{
template <int NumDimSpatial,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
static std::vector<DeviceConvFwdNoOpPtr> Get()
{
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
}
return conv_ptrs;
}
};
template <>
struct ConvolutionFwdInstances<bhalf_t, bhalf_t, bhalf_t>
{
template <int NumDimSpatial,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
static std::vector<DeviceConvFwdNoOpPtr> Get()
{
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
}
return conv_ptrs;
}
};
template <>
struct ConvolutionFwdInstances<int8_t, int8_t, int8_t>
{
template <int NumDimSpatial,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
static std::vector<DeviceConvFwdNoOpPtr> Get()
{
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
}
return conv_ptrs;
}
};
template <typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout = ck::tensor_layout::convolution::NHWC,
typename WeiLayout = ck::tensor_layout::convolution::KYXC,
typename OutLayout = ck::tensor_layout::convolution::NHWK,
typename InElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename WeiElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename OutElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
typename InputInitFun = FillUniformDistribution<InDataType>,
typename WeightsInitFun = FillUniformDistribution<WeiDataType>>
class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>
{
using DeviceConvFwdOp = tensor_operation::device::
DeviceConvFwd<InElementwiseOp, WeiElementwiseOp, OutElementwiseOp>;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
using DeviceBuffers = std::vector<DeviceMemPtr>;
using BaseType = ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>;
template <typename T>
using TensorPtr = std::unique_ptr<Tensor<T>>;
using InTensorsTuple = std::tuple<TensorPtr<InDataType>, TensorPtr<WeiDataType>>;
public:
ConvFwdOpInstance() = delete;
ConvFwdOpInstance(const ConvFwdOpInstance&) = default;
ConvFwdOpInstance& operator=(const ConvFwdOpInstance&) = default;
ConvFwdOpInstance(const ConvParams& params,
bool do_init = true,
const InputInitFun& input_init_f = InputInitFun(),
const WeightsInitFun& weights_init_f = WeightsInitFun())
: BaseType(),
params_{params},
output_spatial_lengths_{params.GetOutputSpatialLengths()},
do_init_{do_init},
input_init_f_{input_init_f},
weights_init_f_{weights_init_f}
{
}
~ConvFwdOpInstance() override = default;
virtual InTensorsTuple GetInputTensors() const override
{
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params_.N_),
static_cast<std::size_t>(params_.C_)};
input_dims.insert(std::end(input_dims),
std::begin(params_.input_spatial_lengths_),
std::end(params_.input_spatial_lengths_));
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params_.K_),
static_cast<std::size_t>(params_.C_)};
filter_dims.insert(std::end(filter_dims),
std::begin(params_.filter_spatial_lengths_),
std::end(params_.filter_spatial_lengths_));
auto input = std::make_unique<Tensor<InDataType>>(
get_host_tensor_descriptor(input_dims, InLayout{}));
auto weights = std::make_unique<Tensor<WeiDataType>>(
get_host_tensor_descriptor(filter_dims, WeiLayout{}));
if(do_init_)
{
input_init_f_(input->begin(), input->end());
weights_init_f_(weights->begin(), weights->end());
}
return std::make_tuple(std::move(input), std::move(weights));
}
virtual TensorPtr<OutDataType> GetOutputTensor() const override
{
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params_.N_),
static_cast<std::size_t>(params_.K_)};
output_dims.insert(std::end(output_dims),
std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_));
auto output = std::make_unique<Tensor<OutDataType>>(
get_host_tensor_descriptor(output_dims, OutLayout{}));
if(do_init_)
{
std::fill(output->begin(), output->end(), OutDataType(0.f));
}
return output;
}
virtual std::unique_ptr<tensor_operation::device::BaseInvoker>
MakeInvokerPointer(tensor_operation::device::BaseOperator* op_ptr) const override
{
static_assert(
std::is_same_v<InElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
static_assert(
std::is_same_v<OutElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
static_assert(
std::is_same_v<WeiElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
auto conv_ptr = dynamic_cast<DeviceConvFwdOp*>(op_ptr);
if(!conv_ptr)
{
throw std::runtime_error(
"[ConvFwdOpInstance]: couldn't cast op_ptr to DeviceConvFwdNoOpPtr type!");
}
return conv_ptr->MakeInvokerPointer();
}
virtual std::unique_ptr<tensor_operation::device::BaseArgument>
MakeArgumentPointer(tensor_operation::device::BaseOperator* op_ptr,
const DeviceBuffers& in_device_buffers,
const DeviceMemPtr& out_device_buffer) const override
{
static_assert(
std::is_same_v<InElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
static_assert(
std::is_same_v<OutElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
static_assert(
std::is_same_v<WeiElementwiseOp, ck::tensor_operation::element_wise::PassThrough>);
auto conv_ptr = dynamic_cast<DeviceConvFwdOp*>(op_ptr);
if(!conv_ptr)
{
throw std::runtime_error(
"[ConvFwdOpInstance]: couldn't cast op_ptr to DeviceConvFwdNoOpPtr type!");
}
return conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buffers[0]->GetDeviceBuffer()),
static_cast<WeiDataType*>(in_device_buffers[1]->GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buffer->GetDeviceBuffer()),
params_.N_,
params_.K_,
params_.C_,
params_.input_spatial_lengths_,
params_.filter_spatial_lengths_,
output_spatial_lengths_,
params_.conv_filter_strides_,
params_.conv_filter_dilations_,
params_.input_left_pads_,
params_.input_right_pads_,
InElementwiseOp{},
WeiElementwiseOp{},
OutElementwiseOp{});
}
virtual std::size_t GetFlops() const override
{
return get_flops(params_.N_,
params_.C_,
params_.K_,
params_.filter_spatial_lengths_,
output_spatial_lengths_);
}
virtual std::size_t GetBtype() const override
{
return get_btype<InDataType, WeiDataType, OutDataType>(params_.N_,
params_.C_,
params_.K_,
params_.input_spatial_lengths_,
params_.filter_spatial_lengths_,
output_spatial_lengths_);
}
private:
const ConvParams& params_;
const std::vector<ck::index_t> output_spatial_lengths_;
const bool do_init_;
InputInitFun input_init_f_;
WeightsInitFun weights_init_f_;
};
} // namespace conv
} // namespace utils
} // namespace ck
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParams& p);
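Taken together, the pieces above support a host-side flow like the following sketch; conv_fwd_f16_example is a hypothetical wrapper, and the default-constructed ConvParams values live in the .cpp:

// Set up an f16 2D problem, produce a host reference result, and enumerate
// the device instances that could run the same problem.
inline void conv_fwd_f16_example()
{
    using namespace ck::utils::conv;

    ConvParams params; // default problem

    ConvFwdOpInstance<ck::half_t, ck::half_t, ck::half_t> op_instance(params);
    auto [input, weights] = op_instance.GetInputTensors();
    auto output           = op_instance.GetOutputTensor();

    // Host-side reference convolution over the same tensors.
    run_reference_convolution_forward<2, ck::half_t, ck::half_t, ck::half_t>(
        params, *input, *weights, *output);

    // Device instances matching the data types and dimensionality.
    auto conv_ptrs = ConvolutionFwdInstances<ck::half_t, ck::half_t, ck::half_t>::Get<2>();
    (void)conv_ptrs;
}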
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
namespace ck {
namespace utils {
namespace conv {
namespace detail {
template <typename OldLayout>
std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
// is used by some legacy kernels. New kernels should use GNHWC/GKYXC/GNHWK
// TODO: remove this branch after removing legacy kernel
if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWK>)
{
return {0, 1, 3, 2};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWK>)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KZYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWK>)
{
return {0, 1, 5, 2, 3, 4};
}
// separate from legacy code above
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKW>)
{
return {0, 1, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
{
return {0, 1, 2, 3, 4};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCDHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCZYX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKDHW>)
{
return {0, 1, 2, 3, 4, 5};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNWK>)
{
return {0, 1, 3, 2};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNHWK>)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNDHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKZYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNDHWK>)
{
return {0, 1, 5, 2, 3, 4};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KXGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWGK>)
{
return {2, 0, 3, 1};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KYXGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWGK>)
{
return {3, 0, 4, 1, 2};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KZYXGC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWGK>)
{
return {4, 0, 5, 1, 2, 3};
}
else
{
printf("%s\n", __func__);
throw std::runtime_error("wrong! unsupported layout");
}
}
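// Worked example: NHWGC stores dimensions physically as (N, H, W, G, C), so
// reading them back in logical GNCHW order picks old indices
//   G -> 3, N -> 0, C -> 4, H -> 1, W -> 2,
// i.e. exactly the {3, 0, 4, 1, 2} returned above.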
} // namespace detail
// make tensor descriptor for a packed input tensor, with dimensions ordered as GNCHW
// regardless of physical layout
template <typename InLayout>
HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
// is used by some legacy kernels. New kernels should use GNHWC/GKYXC/GNHWK
// TODO: remove this branch after removing legacy kernel
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNDHWC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWGC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWGC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", InLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<InLayout>());
}
// make tensor descriptor for a packed weight tensor, with dimensions ordered as GKCYX
// regardless of physical layout
template <typename WeiLayout>
HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck::utils::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
// is used by some legacy kernels. New kernels should use GNHWC/GKYXC/GNHWK
// TODO: remove this branch after removing legacy kernel
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCYX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKCZYX>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::GKZYXC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXGC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXGC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", WeiLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
}
// make tensor descriptor for a packed output tensor, with dimensions ordered as GNKHW
// regardless of physical layout
template <typename OutLayout>
HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
// HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function,
// is used by some legacy kernels. New kernels should use GNHWC/GKYXC/GNHWK
// TODO: remove this branch after removing legacy kernel
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWGK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWGK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWGK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", OutLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
}
} // namespace conv
} // namespace utils
} // namespace ck
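A short sketch of how these helpers are consumed; describe_input_example is a hypothetical wrapper, and it assumes host_tensor.hpp is also in scope for HostTensorDescriptor and Tensor:

inline void describe_input_example(const ck::utils::conv::ConvParam& param)
{
    namespace conv_layout = ck::tensor_layout::convolution;

    // Lengths come back ordered (G, N, C, spatial...) regardless of the
    // physical NHWGC layout; the strides encode the real memory order.
    HostTensorDescriptor in_desc =
        ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
            conv_layout::NHWGC>(param);

    Tensor<ck::half_t> input(in_desc);
    (void)input;
}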
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
#include "ck/ck.hpp"
namespace ck {
namespace utils {
namespace conv {
struct ConvParam
{
ConvParam();
ConvParam(ck::index_t n_dim,
ck::index_t group_count,
ck::index_t n_batch,
ck::index_t n_out_channels,
ck::index_t n_in_channels,
const std::vector<ck::index_t>& filters_len,
const std::vector<ck::index_t>& input_len,
const std::vector<ck::index_t>& strides,
const std::vector<ck::index_t>& dilations,
const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads);
ck::index_t num_dim_spatial_;
ck::index_t G_;
ck::index_t N_;
ck::index_t K_;
ck::index_t C_;
std::vector<ck::index_t> filter_spatial_lengths_;
std::vector<ck::index_t> input_spatial_lengths_;
std::vector<ck::index_t> output_spatial_lengths_;
std::vector<ck::index_t> conv_filter_strides_;
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
std::vector<ck::index_t> GetOutputSpatialLengths() const;
std::size_t GetFlops() const;
template <typename InDataType, typename WeiDataType, typename OutDataType>
std::size_t GetByte() const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(InDataType) *
(G_ * N_ * C_ *
std::accumulate(std::begin(input_spatial_lengths_),
std::begin(input_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(WeiDataType) *
(G_ * K_ * C_ *
std::accumulate(std::begin(filter_spatial_lengths_),
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(OutDataType) * (G_ * N_ * K_ *
std::accumulate(std::begin(output_spatial_lengths_),
                std::begin(output_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
};
std::string get_conv_param_parser_helper_msg();
ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]);
} // namespace conv
} // namespace utils
} // namespace ck
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p);
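For orientation, a hedged sketch of how the counters typically feed a profiler report; report_perf_example is a hypothetical helper, and GetFlops' exact convention (2 ops per MAC) is an assumption since its body is not shown here:

inline void report_perf_example(const ck::utils::conv::ConvParam& param, float avg_time_ms)
{
    const std::size_t flop      = param.GetFlops();
    const std::size_t num_bytes = param.GetByte<ck::half_t, ck::half_t, ck::half_t>();

    const float tflops   = static_cast<float>(flop) / 1.E9 / avg_time_ms;       // flop / (ms * 1e9)
    const float gb_per_s = static_cast<float>(num_bytes) / 1.E6 / avg_time_ms;  // byte / (ms * 1e6)
    (void)tflops;
    (void)gb_per_s;
}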
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/utility/reduction_common.hpp" #include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp" #include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/host_tensor/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
template <int NDim> template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths, static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
......
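Judging by the signature above (its body is collapsed in this view), get_all_indexes enumerates every coordinate tuple inside an NDim-dimensional box. A generic sketch of that pattern, not the file's actual implementation (assumes <array> and <vector> are available):

// Hedged sketch: enumerate all indexes of a box with the given dim lengths,
// in row-major order, by decoding a flat counter digit by digit.
template <int NDim>
static std::vector<std::array<size_t, NDim>>
all_indexes_sketch(const std::array<size_t, NDim>& dimLengths)
{
    size_t total = 1;
    for(size_t len : dimLengths)
        total *= len;

    std::vector<std::array<size_t, NDim>> result;
    result.reserve(total);

    for(size_t flat = 0; flat < total; ++flat)
    {
        std::array<size_t, NDim> idx{};
        size_t rem = flat;
        for(int d = NDim - 1; d >= 0; --d)
        {
            idx[d] = rem % dimLengths[d];
            rem /= dimLengths[d];
        }
        result.push_back(idx);
    }
    return result;
}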
@@ -73,15 +73,21 @@ auto construct_f_unpack_args(F, T args)
 struct HostTensorDescriptor
 {
-    HostTensorDescriptor() = delete;
+    HostTensorDescriptor() = default;

-    template <typename X>
-    HostTensorDescriptor(const std::vector<X>& lens);
+    void CalculateStrides();

-    template <typename X, typename Y>
-    HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides);
+    template <typename X>
+    HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
+    {
+        this->CalculateStrides();
+    }

-    void CalculateStrides();
+    template <typename X>
+    HostTensorDescriptor(const std::vector<X>& lens) : mLens(lens.begin(), lens.end())
+    {
+        this->CalculateStrides();
+    }

     template <typename Range>
     HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
@@ -89,6 +95,19 @@ struct HostTensorDescriptor
         this->CalculateStrides();
     }

+    template <typename X, typename Y>
+    HostTensorDescriptor(const std::initializer_list<X>& lens,
+                         const std::initializer_list<Y>& strides)
+        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
+    {
+    }
+
+    template <typename X, typename Y>
+    HostTensorDescriptor(const std::vector<X>& lens, const std::vector<Y>& strides)
+        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
+    {
+    }
+
     template <typename Range1, typename Range2>
     HostTensorDescriptor(const Range1& lens, const Range2& strides)
         : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
@@ -97,7 +116,7 @@ struct HostTensorDescriptor
     std::size_t GetNumOfDimension() const;
     std::size_t GetElementSize() const;
-    std::size_t GetElementSpace() const;
+    std::size_t GetElementSpaceSize() const;

     const std::vector<std::size_t>& GetLengths() const;
     const std::vector<std::size_t>& GetStrides() const;
@@ -122,6 +141,22 @@ struct HostTensorDescriptor
     std::vector<std::size_t> mStrides;
 };

+template <typename New2Old>
+HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
+                                                                    const New2Old& new2old)
+{
+    std::vector<std::size_t> new_lengths(a.GetNumOfDimension());
+    std::vector<std::size_t> new_strides(a.GetNumOfDimension());
+
+    for(std::size_t i = 0; i < a.GetNumOfDimension(); i++)
+    {
+        new_lengths[i] = a.GetLengths()[new2old[i]];
+        new_strides[i] = a.GetStrides()[new2old[i]];
+    }
+
+    return HostTensorDescriptor(new_lengths, new_strides);
+}
+
 struct joinable_thread : std::thread
 {
     template <typename... Xs>
@@ -203,22 +238,22 @@ template <typename T>
 struct Tensor
 {
     template <typename X>
-    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
+    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
     {
     }

     template <typename X>
-    Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
+    Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
     {
     }

     template <typename X, typename Y>
     Tensor(std::vector<X> lens, std::vector<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.GetElementSpace())
+        : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize())
     {
     }

-    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
+    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}

     template <typename OutT>
     Tensor<OutT> CopyAsType()
@@ -240,6 +275,24 @@ struct Tensor
         return *this;
     }

+    const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
+
+    const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
+
+    std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
+
+    std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
+
+    std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
+
+    void SetZero()
+    {
+        for(auto& v : mData)
+        {
+            v = T{0};
+        }
+    }
+
     template <typename F>
     void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
     {
@@ -330,6 +383,19 @@ struct Tensor
                                        mDesc.GetLengths()[4])(num_thread);
             break;
         }
+        case 6: {
+            auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
+                (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
+            };
+            make_ParallelTensorFunctor(f,
+                                       mDesc.GetLengths()[0],
+                                       mDesc.GetLengths()[1],
+                                       mDesc.GetLengths()[2],
+                                       mDesc.GetLengths()[3],
+                                       mDesc.GetLengths()[4],
+                                       mDesc.GetLengths()[5])(num_thread);
+            break;
+        }
         default: throw std::runtime_error("unsupported dimension");
         }
     }
@@ -367,17 +433,3 @@ struct Tensor
     HostTensorDescriptor mDesc;
     std::vector<T> mData;
 };
-
-template <typename X>
-HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens)
-    : mLens(lens.begin(), lens.end())
-{
-    this->CalculateStrides();
-}
-
-template <typename X, typename Y>
-HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
-                                           const std::vector<Y>& strides)
-    : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
-{
-}
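Taken together, the constructors moved in-class and the GetElementSpace -> GetElementSpaceSize rename give host code one consistent sizing rule. A short sketch with arbitrary shapes (assuming the usual 1 + sum over i of (len_i - 1) * stride_i element-space accounting):

// Hedged sketch: brace-init and strided construction of host tensors.
Tensor<float> a({8, 16}); // packed 8x16, strides {16, 1}
a.SetZero();              // helper added in this commit

Tensor<float> b(std::vector<std::size_t>{8, 16},  // 8x16 logical extent ...
                std::vector<std::size_t>{32, 1}); // ... over a row pitch of 32
// b.GetElementSize() == 128, while b.GetElementSpaceSize() == 1 + 7*32 + 15*1 == 240,
// which is the footprint a device allocation must actually cover.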
@@ -16,8 +16,8 @@
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/library/host_tensor/device_memory.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace utils {
@@ -103,8 +103,8 @@ class OpInstanceRunEngine
             }
         }

         AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
-        out_device_buffer_ =
-            std::make_unique<DeviceMem>(sizeof(OutDataType) * out_tensor_->mDesc.GetElementSpace());
+        out_device_buffer_ = std::make_unique<DeviceMem>(sizeof(OutDataType) *
+                                                         out_tensor_->mDesc.GetElementSpaceSize());
         out_device_buffer_->SetZero();
     }
@@ -222,7 +222,7 @@ class OpInstanceRunEngine
         in_device_buffers_
             .emplace_back(
                 std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<Index, InArgsTypesTuple>) *
-                                            ts->mDesc.GetElementSpace()))
+                                            ts->mDesc.GetElementSpaceSize()))
             ->ToDevice(ts->mData.data());
     }
......
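The renamed accessor is what sizes every device allocation in the run engine; the same staging pattern in isolation (only the DeviceMem calls visible in this diff are assumed):

// Hedged sketch: mirror a host tensor on the GPU the way the engine does above.
Tensor<ck::half_t> in({64, 256});
DeviceMem in_dev(sizeof(ck::half_t) * in.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data()); // host-to-device copy of the backing storage

Tensor<ck::half_t> out({64, 64});
DeviceMem out_dev(sizeof(ck::half_t) * out.mDesc.GetElementSpaceSize());
out_dev.SetZero();                // outputs start zeroed before a kernel runs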
## host_tensor
set(HOST_TENSOR_SOURCE
device_memory.cpp
host_tensor.cpp
)
add_library(host_tensor STATIC ${HOST_TENSOR_SOURCE})
add_library(composable_kernel::host_tensor ALIAS host_tensor)
target_compile_features(host_tensor PUBLIC)
set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(host_tensor SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_include_directories(host_tensor PUBLIC
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck>"
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/utility>"
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host_tensor>"
)
rocm_install(
TARGETS host_tensor
EXPORT host_tensorTargets
)
rocm_install(
EXPORT host_tensorTargets
FILE composable_kernelhost_tensorTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
clang_tidy_check(host_tensor)
@@ -16,15 +16,18 @@ add_subdirectory(batched_gemm_reduce)
 add_subdirectory(grouped_gemm)
 add_subdirectory(contraction_scale)
 add_subdirectory(contraction_bilinear)
-add_subdirectory(conv1d_fwd)
+add_subdirectory(grouped_conv1d_fwd)
+add_subdirectory(grouped_conv2d_fwd)
+add_subdirectory(grouped_conv3d_fwd)
 add_subdirectory(conv2d_fwd)
-add_subdirectory(conv3d_fwd)
-add_subdirectory(conv2d_fwd_bias_relu)
-add_subdirectory(conv2d_fwd_bias_relu_add)
+add_subdirectory(conv1d_bwd_data)
 add_subdirectory(conv2d_bwd_data)
-add_subdirectory(convnd_bwd_data)
+add_subdirectory(conv3d_bwd_data)
+add_subdirectory(conv1d_bwd_weight)
 add_subdirectory(conv2d_bwd_weight)
-add_subdirectory(convnd_bwd_weight)
+add_subdirectory(conv3d_bwd_weight)
+add_subdirectory(conv2d_fwd_bias_relu)
+add_subdirectory(conv2d_fwd_bias_relu_add)
 add_subdirectory(reduce)
 add_subdirectory(normalization)
 add_subdirectory(elementwise)
@@ -40,15 +43,17 @@ add_library(device_operations STATIC
     $<TARGET_OBJECTS:device_grouped_gemm_instance>
     $<TARGET_OBJECTS:device_contraction_scale_instance>
     $<TARGET_OBJECTS:device_contraction_bilinear_instance>
-    $<TARGET_OBJECTS:device_conv1d_fwd_instance>
-    $<TARGET_OBJECTS:device_conv2d_fwd_instance>
-    $<TARGET_OBJECTS:device_conv3d_fwd_instance>
-    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
-    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
+    $<TARGET_OBJECTS:device_grouped_conv1d_fwd_instance>
+    $<TARGET_OBJECTS:device_grouped_conv2d_fwd_instance>
+    $<TARGET_OBJECTS:device_grouped_conv3d_fwd_instance>
+    $<TARGET_OBJECTS:device_conv1d_bwd_data_instance>
     $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
-    $<TARGET_OBJECTS:device_convnd_bwd_data_instance>
+    $<TARGET_OBJECTS:device_conv3d_bwd_data_instance>
+    $<TARGET_OBJECTS:device_conv1d_bwd_weight_instance>
     $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
-    $<TARGET_OBJECTS:device_convnd_bwd_weight_instance>
+    $<TARGET_OBJECTS:device_conv3d_bwd_weight_instance>
+    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
+    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
     $<TARGET_OBJECTS:device_reduce_instance>
     $<TARGET_OBJECTS:device_normalization_instance>
     $<TARGET_OBJECTS:device_elementwise_instance>
@@ -75,7 +80,7 @@ target_include_directories(device_operations PUBLIC
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/warp>
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/thread>
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/element>
-    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host_tensor>
+    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/utility>
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/reduce>
......
@@ -22,7 +22,7 @@ namespace device {
 namespace instance {

 using F32 = float;
-using F32_TUPLE = ck::Tuple<F32>;
+using F32_Tuple = ck::Tuple<F32>;

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -40,19 +40,19 @@ using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_in
 //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>
 // clang-format on
 >;
@@ -62,7 +62,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn
 2,
 F32,
 F32,
-F32_TUPLE,
+F32_Tuple,
 F32,
 PassThrough,
 PassThrough,
......
@@ -22,7 +22,7 @@ namespace device {
 namespace instance {

 using F32 = float;
-using F32_TUPLE = ck::Tuple<F32>;
+using F32_Tuple = ck::Tuple<F32>;

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -40,22 +40,22 @@ using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_in
 //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
-DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_TUPLE, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>,
+DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
 // clang-format on
 >;
@@ -65,7 +65,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn
 2,
 F32,
 F32,
-F32_TUPLE,
+F32_Tuple,
 F32,
 PassThrough,
 PassThrough,
......
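Every instance in the two lists above uses the Bilinear elementwise epilogue: per output element, the contraction result is blended with the extra D tensor through the alpha/beta coefficients. A scalar model of that arithmetic (the free function below is illustrative, not the library's functor):

// Hedged scalar model of the Bilinear epilogue: e = alpha * (A x B) + beta * d.
float bilinear_epilogue(float alpha, float beta, float ab, float d)
{
    return alpha * ab + beta * d;
}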