Commit c6891e12 authored by rocking's avatar rocking
Browse files

Merge branch 'develop' into standalone-layernorm

parents f591ad27 8e374781
......@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
......@@ -40,7 +40,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -159,7 +159,7 @@ check_err(const std::vector<T>& out,
const std::vector<T>& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double = 0)
double atol = 0)
{
if(out.size() != ref.size())
{
......@@ -179,7 +179,7 @@ check_err(const std::vector<T>& out,
int64_t r = ref[i];
err = std::abs(o - r);
if(err > 0)
if(err > atol)
{
max_err = err > max_err ? err : max_err;
err_count++;
......
......@@ -31,15 +31,15 @@ namespace device {
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
element_wise::PassThrough,
element_wise::PassThrough>;
namespace device_conv1d_fwd_instance {
namespace instance {
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv1d_fwd_instance
namespace device_conv2d_fwd_instance {
} // namespace instance
namespace instance {
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
......@@ -48,15 +48,15 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv2d_fwd_instance
namespace device_conv3d_fwd_instance {
} // namespace instance
namespace instance {
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv3d_fwd_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
......@@ -295,17 +295,17 @@ struct ConvolutionFwdInstances<float, float, float>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::device_conv1d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::device_conv2d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::device_conv3d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
}
return conv_ptrs;
......@@ -322,20 +322,20 @@ struct ConvolutionFwdInstances<half_t, half_t, half_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::device_conv1d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
return conv_ptrs;
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::device_conv2d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
ck::tensor_operation::device::device_conv2d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::device_conv3d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
}
return conv_ptrs;
......@@ -352,17 +352,17 @@ struct ConvolutionFwdInstances<bhalf_t, bhalf_t, bhalf_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::device_conv1d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::device_conv2d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::device_conv3d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
}
return conv_ptrs;
......@@ -379,17 +379,17 @@ struct ConvolutionFwdInstances<int8_t, int8_t, int8_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
ck::tensor_operation::device::device_conv1d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
ck::tensor_operation::device::device_conv2d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
ck::tensor_operation::device::device_conv3d_fwd_instance::
ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
}
return conv_ptrs;
......
......@@ -25,6 +25,7 @@ add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_bwd_data)
add_subdirectory(convnd_bwd_data)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(normalization)
add_subdirectory(reduce)
add_library(device_operations STATIC
......
......@@ -7,12 +7,13 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
......@@ -28,29 +29,31 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
// clang-format on
>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
......@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
......@@ -48,13 +48,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
......@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
......@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
......@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
......@@ -53,13 +53,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
......@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
......@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
......@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
......@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
......@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
Row,
Row,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
Col,
Row,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<DeviceBatchedGemm<Row,
Row,
Row,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......@@ -51,13 +51,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
std::vector<std::unique_ptr<DeviceBatchedGemm<Row,
Col,
Row,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment