Commit c6891e12 authored by rocking's avatar rocking
Browse files

Merge branch 'develop' into standalone-layernorm

parents f591ad27 8e374781
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace instance {
// clang-format off // clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...@@ -40,7 +40,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); ...@@ -40,7 +40,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on // clang-format on
} // namespace device_reduce_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -159,7 +159,7 @@ check_err(const std::vector<T>& out, ...@@ -159,7 +159,7 @@ check_err(const std::vector<T>& out,
const std::vector<T>& ref, const std::vector<T>& ref,
const std::string& msg = "Error: Incorrect results!", const std::string& msg = "Error: Incorrect results!",
double = 0, double = 0,
double = 0) double atol = 0)
{ {
if(out.size() != ref.size()) if(out.size() != ref.size())
{ {
...@@ -179,7 +179,7 @@ check_err(const std::vector<T>& out, ...@@ -179,7 +179,7 @@ check_err(const std::vector<T>& out,
int64_t r = ref[i]; int64_t r = ref[i];
err = std::abs(o - r); err = std::abs(o - r);
if(err > 0) if(err > atol)
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
......
...@@ -31,15 +31,15 @@ namespace device { ...@@ -31,15 +31,15 @@ namespace device {
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough, using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
element_wise::PassThrough, element_wise::PassThrough,
element_wise::PassThrough>; element_wise::PassThrough>;
namespace device_conv1d_fwd_instance { namespace instance {
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv1d_fwd_instance } // namespace instance
namespace device_conv2d_fwd_instance { namespace instance {
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
...@@ -48,15 +48,15 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( ...@@ -48,15 +48,15 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv2d_fwd_instance } // namespace instance
namespace device_conv3d_fwd_instance { namespace instance {
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv3d_fwd_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
...@@ -295,17 +295,17 @@ struct ConvolutionFwdInstances<float, float, float> ...@@ -295,17 +295,17 @@ struct ConvolutionFwdInstances<float, float, float>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
{ {
ck::tensor_operation::device::device_conv1d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NumDimSpatial == 2)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NumDimSpatial == 3)
{ {
ck::tensor_operation::device::device_conv3d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
} }
return conv_ptrs; return conv_ptrs;
...@@ -322,20 +322,20 @@ struct ConvolutionFwdInstances<half_t, half_t, half_t> ...@@ -322,20 +322,20 @@ struct ConvolutionFwdInstances<half_t, half_t, half_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
{ {
ck::tensor_operation::device::device_conv1d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
return conv_ptrs; return conv_ptrs;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NumDimSpatial == 2)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NumDimSpatial == 3)
{ {
ck::tensor_operation::device::device_conv3d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
} }
return conv_ptrs; return conv_ptrs;
...@@ -352,17 +352,17 @@ struct ConvolutionFwdInstances<bhalf_t, bhalf_t, bhalf_t> ...@@ -352,17 +352,17 @@ struct ConvolutionFwdInstances<bhalf_t, bhalf_t, bhalf_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
{ {
ck::tensor_operation::device::device_conv1d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NumDimSpatial == 2)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NumDimSpatial == 3)
{ {
ck::tensor_operation::device::device_conv3d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
} }
return conv_ptrs; return conv_ptrs;
...@@ -379,17 +379,17 @@ struct ConvolutionFwdInstances<int8_t, int8_t, int8_t> ...@@ -379,17 +379,17 @@ struct ConvolutionFwdInstances<int8_t, int8_t, int8_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
{ {
ck::tensor_operation::device::device_conv1d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NumDimSpatial == 2)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NumDimSpatial == 3)
{ {
ck::tensor_operation::device::device_conv3d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
} }
return conv_ptrs; return conv_ptrs;
......
...@@ -25,6 +25,7 @@ add_subdirectory(conv2d_fwd_bias_relu_add) ...@@ -25,6 +25,7 @@ add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_bwd_data) add_subdirectory(conv2d_bwd_data)
add_subdirectory(convnd_bwd_data) add_subdirectory(convnd_bwd_data)
add_subdirectory(conv2d_bwd_weight) add_subdirectory(conv2d_bwd_weight)
add_subdirectory(normalization)
add_subdirectory(reduce) add_subdirectory(reduce)
add_library(device_operations STATIC add_library(device_operations STATIC
......
...@@ -7,12 +7,13 @@ ...@@ -7,12 +7,13 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using F32 = float;
...@@ -28,29 +29,31 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -28,29 +29,31 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple< using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
// clang-format on // clang-format on
>; >;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using F32 = float;
...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple< ...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{}); device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using F32 = float;
...@@ -48,13 +48,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple< ...@@ -48,13 +48,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using F32 = float;
...@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple< ...@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< ...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< ...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -53,13 +53,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< ...@@ -53,13 +53,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< ...@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple< ...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple< ...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{}); device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple< ...@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple< ...@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< ...@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
Row,
Row,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple< ...@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
Col,
Row,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{}); device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple< ...@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<DeviceBatchedGemm<Row,
Row,
Row,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_batched_gemm_instance { namespace instance {
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -51,13 +51,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple< ...@@ -51,13 +51,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<DeviceBatchedGemm<Row,
Col,
Row,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment