Unverified Commit aebd211c authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

External Interface (#304)

* add client example

* clean

* clean

* reorg

* clean up profiler

* reorg

* clea

* fix profiler

* function for getinstances

* update client example

* update client example

* update client example

* update

* update example

* update Jenkins file

* update cmake

* update Jenkins
parent b653c5eb
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmSplitKNoOpPtr = ck::tensor_operation::device::DeviceGemmSplitKPtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
auto get_device_gemm_splitk_instances()
{
std::vector<DeviceGemmSplitKNoOpPtr> op_ptrs;
if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
is_same<CDataType, float>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -54,25 +54,3 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) ...@@ -54,25 +54,3 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
return os; return os;
} }
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
{
os << "dim " << desc.GetNumOfDimension() << ", ";
os << "lengths {";
LogRange(os, desc.GetLengths(), ", ");
os << "}, ";
os << "strides {";
LogRange(os, desc.GetStrides(), ", ");
os << "}" << std::endl;
}
#if 1
// FIXME: remove
void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst)
{
for(std::size_t i = 0; i < src.mData.size(); ++i)
dst.mData[i] = ck::type_convert<float>(src.mData[i]);
}
#endif
...@@ -6,43 +6,45 @@ function(add_instance_library INSTANCE_NAME) ...@@ -6,43 +6,45 @@ function(add_instance_library INSTANCE_NAME)
endfunction(add_instance_library INSTANCE_NAME) endfunction(add_instance_library INSTANCE_NAME)
add_subdirectory(gemm) add_subdirectory(gemm)
add_subdirectory(gemm_splitk)
add_subdirectory(gemm_bias2d) add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu) add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add) add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce) add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce) add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_subdirectory(batched_gemm) add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm)
add_subdirectory(conv1d_fwd) add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd) add_subdirectory(conv2d_fwd)
add_subdirectory(conv3d_fwd) add_subdirectory(conv3d_fwd)
add_subdirectory(conv2d_fwd_bias_relu) add_subdirectory(conv2d_fwd_bias_relu)
add_subdirectory(conv2d_fwd_bias_relu_add) add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_bwd_data) add_subdirectory(conv2d_bwd_data)
add_subdirectory(reduce)
add_subdirectory(convnd_bwd_data) add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_gemm)
add_subdirectory(conv2d_bwd_weight) add_subdirectory(conv2d_bwd_weight)
add_subdirectory(batched_gemm_reduce) add_subdirectory(reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_library(device_operations STATIC add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_instance> $<TARGET_OBJECTS:device_gemm_instance>
$<TARGET_OBJECTS:device_gemm_splitk_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_instance> $<TARGET_OBJECTS:device_gemm_bias_relu_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance> $<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance> $<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_reduce_instance> $<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_convnd_bwd_data_instance> $<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance> $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv3d_fwd_instance> $<TARGET_OBJECTS:device_conv3d_fwd_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance> $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_convnd_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_reduce_instance>
) )
add_library(composablekernels::device_operations ALIAS device_operations) add_library(composablekernels::device_operations ALIAS device_operations)
...@@ -67,8 +69,8 @@ target_include_directories(device_operations PUBLIC ...@@ -67,8 +69,8 @@ target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/thread> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/thread>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/element> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/element>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host_tensor> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host_tensor>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/reduce> $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/reduce>
) )
......
...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple< ...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{});
......
...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple< ...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{}); device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{});
......
...@@ -48,7 +48,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple< ...@@ -48,7 +48,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{});
......
...@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple< ...@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{});
......
...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< ...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{});
......
...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< ...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{});
......
...@@ -53,7 +53,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< ...@@ -53,7 +53,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{});
......
...@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< ...@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{});
......
...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple< ...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{});
......
...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple< ...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{}); device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{});
......
...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple< ...@@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{});
......
...@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple< ...@@ -49,7 +49,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{});
......
...@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< ...@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{});
......
...@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple< ...@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{}); device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{});
......
...@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple< ...@@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{});
......
...@@ -51,7 +51,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple< ...@@ -51,7 +51,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{});
......
...@@ -67,9 +67,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in ...@@ -67,9 +67,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
>; >;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector< std::vector<DeviceBatchedGemmReducePtr<PassThrough,
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>& PassThrough,
instances) PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment