"third_party/vscode:/vscode.git/clone" did not exist on "575478ecf4a3c9f7b2593f6efc5e7c8c122b237b"
Commit e57bd240 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into add_fp16_wmma_conv_instance
parents 4b70d68e 37a8c1f7
...@@ -89,13 +89,13 @@ void add_device_reduce_instance_blockwise( ...@@ -89,13 +89,13 @@ void add_device_reduce_instance_blockwise(
{ {
static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}( static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
[&](auto i) { [&](auto i) {
using cfg1 = remove_cvref_t<decltype( using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>; reduce_configuration_1_instances_blockwise{}))>;
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}( static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
[&](auto j) { [&](auto j) {
using cfg2 = remove_cvref_t<decltype( using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>; reduce_configuration_2_instances_blockwise{}))>;
using ReduceOpInstance = using ReduceOpInstance =
DeviceReduceMultiBlock<InDataType, DeviceReduceMultiBlock<InDataType,
......
...@@ -90,14 +90,14 @@ void add_device_reduce_instance_multiblock_atomic_add( ...@@ -90,14 +90,14 @@ void add_device_reduce_instance_multiblock_atomic_add(
static_for<0, static_for<0,
std::tuple_size<reduce_configuration_1_instances_multiblock_atomic_add>::value, std::tuple_size<reduce_configuration_1_instances_multiblock_atomic_add>::value,
1>{}([&](auto i) { 1>{}([&](auto i) {
using cfg1 = remove_cvref_t<decltype( using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
std::get<i.value>(reduce_configuration_1_instances_multiblock_atomic_add{}))>; reduce_configuration_1_instances_multiblock_atomic_add{}))>;
static_for<0, static_for<0,
std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value, std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value,
1>{}([&](auto j) { 1>{}([&](auto j) {
using cfg2 = remove_cvref_t<decltype( using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
std::get<j.value>(reduce_configuration_2_instances_multiblock_atomic_add{}))>; reduce_configuration_2_instances_multiblock_atomic_add{}))>;
using ReduceOpInstance = DeviceReduceMultiBlock<InDataType, using ReduceOpInstance = DeviceReduceMultiBlock<InDataType,
AccDataType, AccDataType,
......
...@@ -77,8 +77,8 @@ void add_device_reduce_instance_threadwise( ...@@ -77,8 +77,8 @@ void add_device_reduce_instance_threadwise(
static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}( static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
[&](auto j) { [&](auto j) {
using cfg2 = remove_cvref_t<decltype( using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
std::get<j.value>(reduce_configuration_2_instances_threadwise{}))>; reduce_configuration_2_instances_threadwise{}))>;
using ReduceOpInstance = DeviceReduceThreadWise<InDataType, using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
AccDataType, AccDataType,
......
...@@ -40,7 +40,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma ...@@ -40,7 +40,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> && if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F16>) std::is_same_v<OutDataType, F16>)
{ {
...@@ -65,8 +65,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma ...@@ -65,8 +65,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs); add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs);
} }
} }
else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> && #endif
std::is_same_v<OutDataType, F32>) #ifdef CK_ENABLE_FP32
if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F32>)
{ {
if constexpr(Rank == 3) if constexpr(Rank == 3)
{ {
...@@ -89,7 +91,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma ...@@ -89,7 +91,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs); add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -65,7 +65,11 @@ check_err(const Range& out, ...@@ -65,7 +65,11 @@ check_err(const Range& out,
} }
if(!res) if(!res)
{ {
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
} }
return res; return res;
} }
...@@ -112,7 +116,11 @@ check_err(const Range& out, ...@@ -112,7 +116,11 @@ check_err(const Range& out,
} }
if(!res) if(!res)
{ {
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
} }
return res; return res;
} }
...@@ -158,7 +166,11 @@ check_err(const Range& out, ...@@ -158,7 +166,11 @@ check_err(const Range& out,
} }
if(!res) if(!res)
{ {
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
} }
return res; return res;
} }
...@@ -209,7 +221,11 @@ check_err(const Range& out, ...@@ -209,7 +221,11 @@ check_err(const Range& out,
} }
if(!res) if(!res)
{ {
std::cerr << "max err: " << max_err << std::endl; const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
} }
return res; return res;
} }
......
...@@ -20,12 +20,15 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) ...@@ -20,12 +20,15 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
*/ */
struct DeviceMem struct DeviceMem
{ {
DeviceMem() = delete; DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
DeviceMem(std::size_t mem_size); DeviceMem(std::size_t mem_size);
void Realloc(std::size_t mem_size);
void* GetDeviceBuffer() const; void* GetDeviceBuffer() const;
std::size_t GetBufferSize() const; std::size_t GetBufferSize() const;
void ToDevice(const void* p) const; void ToDevice(const void* p) const;
void ToDevice(const void* p, const std::size_t cpySize) const;
void FromDevice(void* p) const; void FromDevice(void* p) const;
void FromDevice(void* p, const std::size_t cpySize) const;
void SetZero() const; void SetZero() const;
template <typename T> template <typename T>
void SetValue(T x) const; void SetValue(T x) const;
......
...@@ -102,9 +102,10 @@ struct FillMonotonicSeq ...@@ -102,9 +102,10 @@ struct FillMonotonicSeq
} }
template <typename ForwardRange> template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<decltype( auto operator()(ForwardRange&& range) const
std::declval<const FillMonotonicSeq&>()(std::begin(std::forward<ForwardRange>(range)), -> std::void_t<decltype(std::declval<const FillMonotonicSeq&>()(
std::end(std::forward<ForwardRange>(range))))> std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{ {
(*this)(std::begin(std::forward<ForwardRange>(range)), (*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))); std::end(std::forward<ForwardRange>(range)));
......
...@@ -95,6 +95,22 @@ struct GeneratorTensor_2<int8_t> ...@@ -95,6 +95,22 @@ struct GeneratorTensor_2<int8_t>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_2<ck::f8_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
return ck::type_convert<ck::f8_t>(tmp);
}
};
#endif
template <typename T> template <typename T>
struct GeneratorTensor_3 struct GeneratorTensor_3
{ {
...@@ -127,6 +143,25 @@ struct GeneratorTensor_3<ck::bhalf_t> ...@@ -127,6 +143,25 @@ struct GeneratorTensor_3<ck::bhalf_t>
} }
}; };
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
template <>
struct GeneratorTensor_3<ck::f8_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = float(std::rand()) / float(RAND_MAX);
float fp32_tmp = min_value + tmp * (max_value - min_value);
return ck::type_convert<ck::f8_t>(fp32_tmp);
}
};
#endif
template <typename T> template <typename T>
struct GeneratorTensor_4 struct GeneratorTensor_4
{ {
......
...@@ -39,11 +39,15 @@ IF(IS_DIRECTORY "${subdir_path}") ...@@ -39,11 +39,15 @@ IF(IS_DIRECTORY "${subdir_path}")
#message("int8 instance found!") #message("int8 instance found!")
set(add_inst 1) set(add_inst 1)
endif() endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES") if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES)
#message("instance should be built for all types!") #message("instance should be built for all types!")
set(add_inst 1) set(add_inst 1)
endif() endif()
if(add_inst EQUAL 1 OR NOT DEFINED DTYPES) if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
if(add_inst EQUAL 1)
get_filename_component(target_dir ${subdir_path} NAME) get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir}) add_subdirectory(${target_dir})
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>) list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
......
set(DEVICE_AVGPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
endif()
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I32 = int32_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using NDHWC = ck::tensor_layout::convolution::NDHWC;
using device_avgpool_bwd_ndhwc_f16_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
using device_avgpool_bwd_ndhwc_bf16_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
using device_avgpool_bwd_ndhwc_f32_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "pool_fwd_instance_common.hpp" #include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; void add_device_avgpool_bwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, BF16, BF16, NDHWC, NDHWC>>>& instances)
void add_device_pool2d_fwd_nhwc_f16_instances( {
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances) add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_bf16_instances{});
{ }
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{}); } // namespace instance
} } // namespace device
} // namespace tensor_operation
} // namespace instance } // namespace ck
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "pool_fwd_instance_common.hpp" #include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; void add_device_avgpool_bwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F16, F16, NDHWC, NDHWC>>>& instances)
void add_device_pool2d_fwd_nhwc_f32_instances( {
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances) add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f16_instances{});
{ }
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{}); } // namespace instance
} } // namespace device
} // namespace tensor_operation
void add_device_pool2d_fwd_nhwc_index_f32_instances( } // namespace ck
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, true>>>& instances)
{
add_device_operation_instances(
instances, device_pool2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_bwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F32, F32, NDHWC, NDHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
add_instance_library(device_batched_gemm_instance set(BATCHED_GEMM_INSTANCES)
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp endif()
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp endif()
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
) device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
endif()
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})
...@@ -25,6 +25,16 @@ using S = ck::Sequence<Is...>; ...@@ -25,6 +25,16 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
// clang-format on
>;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
// clang-format off // clang-format off
...@@ -100,6 +110,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( ...@@ -100,6 +110,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances) instances)
{ {
add_device_operation_instances(
instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{});
} }
......
...@@ -25,6 +25,16 @@ using S = ck::Sequence<Is...>; ...@@ -25,6 +25,16 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
// clang-format on
>;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
// clang-format off // clang-format off
...@@ -88,6 +98,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( ...@@ -88,6 +98,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances) instances)
{ {
add_device_operation_instances(
instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{});
} }
......
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
) )
endif()
\ No newline at end of file
...@@ -31,22 +31,24 @@ using CDE1ElementOp = ck::tensor_operation::element_wise::Add; ...@@ -31,22 +31,24 @@ using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcVectorDim| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
//generic
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
// no padding // no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel // Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8> DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
......
...@@ -31,23 +31,25 @@ using CDE1ElementOp = ck::tensor_operation::element_wise::Add; ...@@ -31,23 +31,25 @@ using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances = using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcVectorDim| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
//generic
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
// no padding // no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel // Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8> DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment