"desktop/postcss.config.js" did not exist on "ebec1c61db24859c415eed66d31827f8fce9e744"
Commit bd0f0686 authored by Jing Zhang

merge develop

parents e9b1000f 63914743
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -24,7 +24,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_threadwise = std::tuple<
@@ -151,7 +151,7 @@ void add_device_reduce_instance_threadwise(
Rank, \
NumReduceDim)
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -53,7 +53,7 @@ ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -40,7 +40,7 @@ ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -28,7 +28,7 @@ ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -52,7 +52,7 @@ ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -28,7 +28,7 @@ ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -52,7 +52,7 @@ ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,7 +10,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_reduce_instance {
+namespace instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -40,7 +40,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
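
All ten reduce-instance hunks above are the same mechanical rename, collapsing the per-operator namespace `device_reduce_instance` into the shared `instance` namespace. For out-of-tree code that still spells the old name, a namespace alias keeps it compiling while call sites migrate; this shim is hypothetical and not part of the commit:

```cpp
// Hypothetical migration shim (not in this commit): lets old spellings compile
// while callers move to the unified ck::tensor_operation::device::instance.
namespace ck::tensor_operation::device {
namespace instance {
} // ensure the target namespace is visible before aliasing
namespace device_reduce_instance = instance; // old name forwards to the new one
} // namespace ck::tensor_operation::device
```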
@@ -159,7 +159,7 @@ check_err(const std::vector<T>& out,
const std::vector<T>& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
-double = 0)
+double atol = 0)
{
if(out.size() != ref.size())
{
@@ -179,7 +179,7 @@ check_err(const std::vector<T>& out,
int64_t r = ref[i];
err = std::abs(o - r);
-if(err > 0)
+if(err > atol)
{
max_err = err > max_err ? err : max_err;
err_count++;
......
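
For reference, a minimal sketch of the integer comparison path of `check_err` after this change. Names and structure are simplified from the diff above (the real function is templated and also takes an unnamed `rtol` parameter); `check_err_int` is an illustrative name, not the library's:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Simplified integer path of check_err: "atol" is the newly named parameter;
// previously it was an unnamed double fixed at 0, so any nonzero difference
// counted as an error.
bool check_err_int(const std::vector<std::int64_t>& out,
                   const std::vector<std::int64_t>& ref,
                   std::int64_t atol = 0)
{
    if(out.size() != ref.size())
        return false; // size mismatch is always an error

    std::size_t err_count  = 0;
    std::int64_t max_err   = 0;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const std::int64_t err = std::abs(out[i] - ref[i]);
        if(err > atol) // the line this commit changes: was err > 0
        {
            max_err = err > max_err ? err : max_err;
            ++err_count;
        }
    }
    return err_count == 0;
}
```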
@@ -31,15 +31,15 @@ namespace device {
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
element_wise::PassThrough,
element_wise::PassThrough>;
-namespace device_conv1d_fwd_instance {
+namespace instance {
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
-} // namespace device_conv1d_fwd_instance
-namespace device_conv2d_fwd_instance {
+} // namespace instance
+namespace instance {
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
@@ -48,15 +48,15 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
-} // namespace device_conv2d_fwd_instance
-namespace device_conv3d_fwd_instance {
+} // namespace instance
+namespace instance {
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
-} // namespace device_conv3d_fwd_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
@@ -295,17 +295,17 @@ struct ConvolutionFwdInstances<float, float, float>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
-ck::tensor_operation::device::device_conv1d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
-ck::tensor_operation::device::device_conv2d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
-ck::tensor_operation::device::device_conv3d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
}
return conv_ptrs;
@@ -322,20 +322,20 @@ struct ConvolutionFwdInstances<half_t, half_t, half_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
-ck::tensor_operation::device::device_conv1d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
return conv_ptrs;
}
else if constexpr(NumDimSpatial == 2)
{
-ck::tensor_operation::device::device_conv2d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
-ck::tensor_operation::device::device_conv2d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
-ck::tensor_operation::device::device_conv3d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
}
return conv_ptrs;
@@ -352,17 +352,17 @@ struct ConvolutionFwdInstances<bhalf_t, bhalf_t, bhalf_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
-ck::tensor_operation::device::device_conv1d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
-ck::tensor_operation::device::device_conv2d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
-ck::tensor_operation::device::device_conv3d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
}
return conv_ptrs;
@@ -379,17 +379,17 @@ struct ConvolutionFwdInstances<int8_t, int8_t, int8_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if constexpr(NumDimSpatial == 1)
{
-ck::tensor_operation::device::device_conv1d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 2)
{
-ck::tensor_operation::device::device_conv2d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
}
else if constexpr(NumDimSpatial == 3)
{
-ck::tensor_operation::device::device_conv3d_fwd_instance::
+ck::tensor_operation::device::instance::
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
}
return conv_ptrs;
......
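
The `ConvolutionFwdInstances` hunks above only re-qualify the factory calls; the registration pattern itself is unchanged. A sketch of a direct caller under the new namespace, using the `add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances` declaration from this diff (assumes the CK headers are on the include path; the helper name is hypothetical):

```cpp
#include <vector>
// Assumes the CK headers declaring DeviceConvFwdNoOpPtr and the instance
// factories are included; exact header paths are not shown in this diff.

std::vector<ck::tensor_operation::device::DeviceConvFwdNoOpPtr> get_conv2d_f32_instances()
{
    std::vector<ck::tensor_operation::device::DeviceConvFwdNoOpPtr> conv_ptrs;
    // After this commit, all factories live in the single "instance" namespace.
    ck::tensor_operation::device::instance::
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
    return conv_ptrs; // every registered 2-D NHWC forward-convolution instance
}
```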
@@ -5,18 +5,17 @@ function(add_instance_library INSTANCE_NAME)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endfunction(add_instance_library INSTANCE_NAME)
add_subdirectory(elementwise)
add_subdirectory(gemm)
add_subdirectory(gemm_splitk)
add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_bilinear)
add_subdirectory(gemm_add_add_fastgelu)
add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm)
add_subdirectory(contraction_scale)
add_subdirectory(contraction_bilinear)
add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd)
add_subdirectory(conv3d_fwd)
@@ -25,19 +24,22 @@ add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_bwd_data)
add_subdirectory(convnd_bwd_data)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(convnd_bwd_weight)
add_subdirectory(reduce)
add_subdirectory(normalization)
add_subdirectory(elementwise)
add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_instance>
$<TARGET_OBJECTS:device_gemm_splitk_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_gemm_bilinear_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_contraction_scale_instance>
$<TARGET_OBJECTS:device_contraction_bilinear_instance>
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv3d_fwd_instance>
@@ -46,9 +48,10 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_convnd_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_elementwise_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_convnd_bwd_weight_instance>
$<TARGET_OBJECTS:device_reduce_instance>
$<TARGET_OBJECTS:device_normalization_instance>
$<TARGET_OBJECTS:device_elementwise_instance>
)
add_library(composablekernels::device_operations ALIAS device_operations)
@@ -80,7 +83,6 @@ target_include_directories(device_operations PUBLIC
#once new arches are enabled make this an option on the main cmake file
# and pass down here to be exported
target_compile_options(device_operations PRIVATE
--offload-arch=gfx908
--offload-arch=gfx90a
......
@@ -7,12 +7,13 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
@@ -28,29 +29,31 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple<
// clang-format off
-//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
-//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
-//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
-//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
-DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
-DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
-DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
-DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
-DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
-DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
-DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
+//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
+DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
+DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
+DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
// clang-format on
>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
-std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+std::vector<std::unique_ptr<
+DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{});
}
-} // namespace device_batched_gemm_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
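
The signature change above is the substantive part of this file: the opaque `DeviceBatchedGemmPtr` alias is replaced by an explicit `std::vector<std::unique_ptr<DeviceBatchedGemm<...>>>`, so the layouts and element types are now encoded in the container type. A hedged caller-side sketch (assumes the CK headers are included; the layout tags and `GetTypeString()` are assumed from the CK API, and the helper name is hypothetical):

```cpp
#include <iostream>
#include <memory>
#include <vector>

void list_bf16_gkm_gkn_gmn_instances() // hypothetical helper, not in this diff
{
    using ck::tensor_operation::device::DeviceBatchedGemm;
    using ck::tensor_operation::element_wise::PassThrough;
    using Col  = ck::tensor_layout::gemm::ColumnMajor; // assumed CK layout tags
    using Row  = ck::tensor_layout::gemm::RowMajor;
    using BF16 = ck::bhalf_t;

    // The container type now spells out layouts and data types explicitly.
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>
        instances;
    ck::tensor_operation::device::instance::
        add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(instances);

    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n'; // print each registered instance
}
```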
@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
-std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+std::vector<std::unique_ptr<
+DeviceBatchedGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{});
}
-} // namespace device_batched_gemm_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
@@ -48,13 +48,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
-std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+std::vector<std::unique_ptr<
+DeviceBatchedGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{});
}
-} // namespace device_batched_gemm_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
-std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+std::vector<std::unique_ptr<
+DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{});
}
-} // namespace device_batched_gemm_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
using F16 = ck::half_t;
using F32 = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
-std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+std::vector<std::unique_ptr<
+DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{});
}
-} // namespace device_batched_gemm_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
using F16 = ck::half_t;
using F32 = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
-std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+std::vector<std::unique_ptr<
+DeviceBatchedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{});
}
-} // namespace device_batched_gemm_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -7,12 +7,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
using F16 = ck::half_t;
using F32 = float;
@@ -53,13 +53,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
>;
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
-std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+std::vector<std::unique_ptr<
+DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{});
}
-} // namespace device_batched_gemm_instance
+} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck