Commit 6ef4e211 authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into contraction

parents b0a2afb9 9e4429f9
@@ -10,7 +10,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -10,7 +10,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -40,7 +40,7 @@ ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -159,7 +159,7 @@ check_err(const std::vector<T>& out,
           const std::vector<T>& ref,
           const std::string& msg = "Error: Incorrect results!",
           double = 0,
-          double = 0)
+          double atol = 0)
 {
     if(out.size() != ref.size())
     {
@@ -179,7 +179,7 @@ check_err(const std::vector<T>& out,
         int64_t r = ref[i];
         err       = std::abs(o - r);
-        if(err > 0)
+        if(err > atol)
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
...
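With this change the integer path of check_err finally honors its absolute-tolerance argument: the trailing double parameter gains the name atol and replaces the hard-coded `err > 0` comparison. A minimal caller sketch, assuming the function is reachable as ck::utils::check_err (the data and tolerance values below are illustrative, not from the commit):

#include <cstdint>
#include <vector>

bool tolerance_demo()
{
    std::vector<int32_t> out = {100, 101, 103};
    std::vector<int32_t> ref = {100, 102, 103};
    // Passes after this commit: |101 - 102| == 1 is within atol == 1.
    // Before, the comparison was err > 0, so any nonzero difference
    // counted as an error regardless of the tolerances passed in.
    return ck::utils::check_err(out, ref, "Error: Incorrect results!", 0, 1);
}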
@@ -31,15 +31,15 @@ namespace device {
 using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
                                               element_wise::PassThrough,
                                               element_wise::PassThrough>;
-namespace device_conv1d_fwd_instance {
+namespace instance {
 void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
 void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
 void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
 void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
-} // namespace device_conv1d_fwd_instance
+} // namespace instance
-namespace device_conv2d_fwd_instance {
+namespace instance {
 void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
 void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
@@ -48,15 +48,15 @@ void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
 void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
 void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
-} // namespace device_conv2d_fwd_instance
+} // namespace instance
-namespace device_conv3d_fwd_instance {
+namespace instance {
 void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
 void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
 void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
 void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
-} // namespace device_conv3d_fwd_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
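All per-operation instance namespaces (device_conv1d_fwd_instance, device_conv2d_fwd_instance, device_conv3d_fwd_instance, and the reduce/GEMM equivalents above) collapse into the single ck::tensor_operation::device::instance namespace. A sketch of an updated call site; the helper function and the alias name are ours, not from the commit:

#include <vector>

// A namespace alias keeps call sites short without reintroducing the
// old per-operation namespaces.
namespace ckdi = ck::tensor_operation::device::instance;

// Hypothetical helper: gather the f32 NHWC conv2d forward instances
// through the consolidated namespace.
std::vector<ck::tensor_operation::device::DeviceConvFwdNoOpPtr> get_conv2d_f32_instances()
{
    std::vector<ck::tensor_operation::device::DeviceConvFwdNoOpPtr> conv_ptrs;
    ckdi::add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
    return conv_ptrs;
}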
@@ -295,17 +295,17 @@ struct ConvolutionFwdInstances<float, float, float>
         std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
         if constexpr(NumDimSpatial == 1)
         {
-            ck::tensor_operation::device::device_conv1d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
         }
         else if constexpr(NumDimSpatial == 2)
         {
-            ck::tensor_operation::device::device_conv2d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
         }
         else if constexpr(NumDimSpatial == 3)
         {
-            ck::tensor_operation::device::device_conv3d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
         }
         return conv_ptrs;
@@ -322,20 +322,20 @@ struct ConvolutionFwdInstances<half_t, half_t, half_t>
         std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
         if constexpr(NumDimSpatial == 1)
         {
-            ck::tensor_operation::device::device_conv1d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
             return conv_ptrs;
         }
         else if constexpr(NumDimSpatial == 2)
         {
-            ck::tensor_operation::device::device_conv2d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
-            ck::tensor_operation::device::device_conv2d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
         }
         else if constexpr(NumDimSpatial == 3)
         {
-            ck::tensor_operation::device::device_conv3d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
         }
         return conv_ptrs;
@@ -352,17 +352,17 @@ struct ConvolutionFwdInstances<bhalf_t, bhalf_t, bhalf_t>
         std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
         if constexpr(NumDimSpatial == 1)
         {
-            ck::tensor_operation::device::device_conv1d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
         }
         else if constexpr(NumDimSpatial == 2)
         {
-            ck::tensor_operation::device::device_conv2d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
         }
         else if constexpr(NumDimSpatial == 3)
         {
-            ck::tensor_operation::device::device_conv3d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
         }
         return conv_ptrs;
@@ -379,17 +379,17 @@ struct ConvolutionFwdInstances<int8_t, int8_t, int8_t>
         std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
         if constexpr(NumDimSpatial == 1)
         {
-            ck::tensor_operation::device::device_conv1d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
         }
         else if constexpr(NumDimSpatial == 2)
         {
-            ck::tensor_operation::device::device_conv2d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
         }
         else if constexpr(NumDimSpatial == 3)
         {
-            ck::tensor_operation::device::device_conv3d_fwd_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
         }
         return conv_ptrs;
...
@@ -11,22 +11,20 @@ target_compile_features(host_tensor PUBLIC)
 set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON)
 target_include_directories(host_tensor SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
 target_include_directories(host_tensor PUBLIC
     "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck>"
     "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/utility>"
     "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host_tensor>"
 )
-install(TARGETS host_tensor
-        EXPORT host_tensorTargets
-        LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
-        ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
-        RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
-        INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
-)
+rocm_install(
+    TARGETS host_tensor
+    EXPORT host_tensorTargets
+)
-install(EXPORT host_tensorTargets
-        FILE composable_kernelhost_tensorTargets.cmake
+rocm_install(
+    EXPORT host_tensorTargets
+    FILE composable_kernelhost_tensorTargets.cmake
     NAMESPACE composable_kernel::
     DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
 )
...
@@ -54,25 +54,3 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
     return os;
 }
-void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
-{
-    os << "dim " << desc.GetNumOfDimension() << ", ";
-    os << "lengths {";
-    LogRange(os, desc.GetLengths(), ", ");
-    os << "}, ";
-    os << "strides {";
-    LogRange(os, desc.GetStrides(), ", ");
-    os << "}" << std::endl;
-}
-#if 1
-// FIXME: remove
-void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst)
-{
-    for(std::size_t i = 0; i < src.mData.size(); ++i)
-        dst.mData[i] = ck::type_convert<float>(src.mData[i]);
-}
-#endif
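The streaming helper is superseded by the operator<< overload kept above, and the bf16-to-f32 helper (already marked FIXME) is dropped without a replacement. Code that relied on it can inline the element-wise conversion, as in this sketch (the function name is ours; the body is the deleted code):

void convert_bf16_tensor_to_f32(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst)
{
    // Element-wise bf16 -> f32 conversion, exactly what the removed
    // bf16_to_f32_ helper did; assumes dst is already sized like src.
    for(std::size_t i = 0; i < src.mData.size(); ++i)
        dst.mData[i] = ck::type_convert<float>(src.mData[i]);
}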
@@ -6,43 +6,46 @@ function(add_instance_library INSTANCE_NAME)
 endfunction(add_instance_library INSTANCE_NAME)
 add_subdirectory(gemm)
-add_subdirectory(gemm_bias2d)
-add_subdirectory(gemm_bias_relu)
-add_subdirectory(gemm_bias_relu_add)
+add_subdirectory(gemm_splitk)
+add_subdirectory(gemm_bilinear)
+add_subdirectory(gemm_add_add_fastgelu)
 add_subdirectory(gemm_reduce)
 add_subdirectory(gemm_bias_add_reduce)
 add_subdirectory(batched_gemm)
+add_subdirectory(batched_gemm_reduce)
+add_subdirectory(grouped_gemm)
 add_subdirectory(conv1d_fwd)
 add_subdirectory(conv2d_fwd)
 add_subdirectory(conv3d_fwd)
 add_subdirectory(conv2d_fwd_bias_relu)
 add_subdirectory(conv2d_fwd_bias_relu_add)
 add_subdirectory(conv2d_bwd_data)
-add_subdirectory(reduce)
 add_subdirectory(convnd_bwd_data)
-add_subdirectory(grouped_gemm)
 add_subdirectory(conv2d_bwd_weight)
-add_subdirectory(batched_gemm_reduce)
-add_subdirectory(gemm_add_add_fastgelu)
+add_subdirectory(reduce)
+add_subdirectory(normalization)
+add_subdirectory(elementwise)
 add_library(device_operations STATIC
-    $<TARGET_OBJECTS:device_conv1d_fwd_instance>
-    $<TARGET_OBJECTS:device_batched_gemm_instance>
-    $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
-    $<TARGET_OBJECTS:device_conv2d_fwd_instance>
-    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
-    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
-    $<TARGET_OBJECTS:device_gemm_instance>
-    $<TARGET_OBJECTS:device_gemm_bias_relu_instance>
-    $<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
-    $<TARGET_OBJECTS:device_gemm_bias2d_instance>
-    $<TARGET_OBJECTS:device_reduce_instance>
-    $<TARGET_OBJECTS:device_convnd_bwd_data_instance>
-    $<TARGET_OBJECTS:device_grouped_gemm_instance>
-    $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
-    $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
-    $<TARGET_OBJECTS:device_conv3d_fwd_instance>
-    $<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
+    $<TARGET_OBJECTS:device_gemm_instance>
+    $<TARGET_OBJECTS:device_gemm_splitk_instance>
+    $<TARGET_OBJECTS:device_gemm_bilinear_instance>
+    $<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
+    $<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
+    $<TARGET_OBJECTS:device_batched_gemm_instance>
+    $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
+    $<TARGET_OBJECTS:device_grouped_gemm_instance>
+    $<TARGET_OBJECTS:device_conv1d_fwd_instance>
+    $<TARGET_OBJECTS:device_conv2d_fwd_instance>
+    $<TARGET_OBJECTS:device_conv3d_fwd_instance>
+    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
+    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
+    $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
+    $<TARGET_OBJECTS:device_convnd_bwd_data_instance>
+    $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
+    $<TARGET_OBJECTS:device_reduce_instance>
+    $<TARGET_OBJECTS:device_normalization_instance>
+    $<TARGET_OBJECTS:device_elementwise_instance>
 )
 add_library(composablekernels::device_operations ALIAS device_operations)
@@ -67,29 +70,24 @@ target_include_directories(device_operations PUBLIC
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/thread>
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/element>
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host_tensor>
-    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/host>
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance>
-    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu>
     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/reduce>
 )
 #once new arches are enabled make this an option on the main cmake file
 # and pass down here to be exported
 target_compile_options(device_operations PRIVATE
     --offload-arch=gfx908
     --offload-arch=gfx90a
 )
 # install(TARGETS device_operations LIBRARY DESTINATION lib)
-install(TARGETS device_operations
-        EXPORT device_operationsTargets
-        LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
-        ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
-        RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
-        INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
-)
-install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
-install(EXPORT device_operationsTargets
+rocm_install(TARGETS device_operations
+             EXPORT device_operationsTargets)
+rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
+rocm_install(EXPORT device_operationsTargets
     FILE composable_kerneldevice_operationsTargets.cmake
     NAMESPACE composable_kernel::
     DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
...
@@ -7,12 +7,13 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using BF16 = ck::bhalf_t;
 using F32  = float;
@@ -28,29 +29,31 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
 using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple<
 // clang-format off
-//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
-//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
-//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
-//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
+//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
+//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
+//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
 DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
 DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
 DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
 DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
 DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
 DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
 DeviceBatchedGemmXdl< BF16, BF16, BF16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
 // clang-format on
 >;
 void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
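The add_device_batched_gemm_* entry points (here and in the sibling files below) now take a vector of std::unique_ptr to the concrete DeviceBatchedGemm interface, parameterized on layouts and element types, instead of the type-erased DeviceGemmPtr. A caller sketch under the new signature; the loop body and the GetTypeString call are illustrative, not from the commit:

#include <iostream>
#include <memory>
#include <vector>

void list_bf16_gkm_gkn_gmn_instances()
{
    using ck::tensor_operation::device::DeviceBatchedGemm;
    using ck::tensor_operation::element_wise::PassThrough;
    using Row  = ck::tensor_layout::gemm::RowMajor;
    using Col  = ck::tensor_layout::gemm::ColumnMajor;
    using BF16 = ck::bhalf_t;

    // The vector's element type must now match the enumerated layouts
    // (column-major A, row-major B and C) and data types exactly.
    std::vector<std::unique_ptr<
        DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>
        instances;

    ck::tensor_operation::device::instance::
        add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(instances);

    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n';
}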
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using BF16 = ck::bhalf_t;
 using F32  = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using BF16 = ck::bhalf_t;
 using F32  = float;
@@ -48,13 +48,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using BF16 = ck::bhalf_t;
 using F32  = float;
@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using F16 = ck::half_t;
 using F32 = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using F16 = ck::half_t;
 using F32 = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using F16 = ck::half_t;
 using F32 = float;
@@ -53,13 +53,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using F16 = ck::half_t;
 using F32 = float;
@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using F16 = ck::half_t;
 using F32 = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using F16 = ck::half_t;
 using F32 = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using F16 = ck::half_t;
 using F32 = float;
@@ -44,13 +44,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using F16 = ck::half_t;
 using F32 = float;
@@ -49,13 +49,15 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
+        instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -7,12 +7,12 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp"
-#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_batched_gemm_instance {
+namespace instance {
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -59,13 +59,21 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
 >;
 void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
-    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
+    std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
+                                                  Row,
+                                                  Row,
+                                                  int8_t,
+                                                  int8_t,
+                                                  int8_t,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances)
 {
     add_device_operation_instances(instances,
                                    device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{});
 }
-} // namespace device_batched_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck