Commit 41b920e2 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 874a78f9 5d718e6b
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_multi_ab_common.hpp" #include "convnd_fwd_activ_multi_ab_common.hpp"
...@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>; ...@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType, using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
AccDataType, AccDataType,
ADataTypes, ADataTypes,
BDataTypes, BDataTypes,
InElementOp, InElementOp,
WeiElementOp>; WeiElementOp>;
#include "../run_convnd_fwd_activ_example.inc" #include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_multi_ab_common.hpp" #include "convnd_fwd_activ_multi_ab_common.hpp"
...@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>; ...@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType, using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
AccDataType, AccDataType,
ADataTypes, ADataTypes,
BDataTypes, BDataTypes,
InElementOp, InElementOp,
WeiElementOp>; WeiElementOp>;
#include "../run_convnd_fwd_activ_example.inc" #include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_multi_ab_common.hpp" #include "convnd_fwd_activ_multi_ab_common.hpp"
...@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>; ...@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType, using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
AccDataType, AccDataType,
ADataTypes, ADataTypes,
BDataTypes, BDataTypes,
InElementOp, InElementOp,
WeiElementOp>; WeiElementOp>;
#include "../run_convnd_fwd_activ_example.inc" #include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_multi_ab_common.hpp" #include "convnd_fwd_activ_multi_ab_common.hpp"
...@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>; ...@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType, using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
AccDataType, AccDataType,
ADataTypes, ADataTypes,
BDataTypes, BDataTypes,
InElementOp, InElementOp,
WeiElementOp>; WeiElementOp>;
#include "../run_convnd_fwd_activ_example.inc" #include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib> #include <cstdlib>
#include <iostream> #include <iostream>
...@@ -100,16 +100,16 @@ template <ck::index_t NDimSpatial, ...@@ -100,16 +100,16 @@ template <ck::index_t NDimSpatial,
typename WeiElementOp, typename WeiElementOp,
typename OutElementOp, typename OutElementOp,
typename DeviceConvNDFwdInstance> typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification, bool run_grouped_conv(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParam& conv_param, const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc, const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc, const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc, const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op, const InElementOp& in_element_op,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
constexpr ck::index_t NumAs = 2; constexpr ck::index_t NumAs = 2;
constexpr ck::index_t NumBs = 2; constexpr ck::index_t NumBs = 2;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -11,7 +11,7 @@ void print_helper_msg() ...@@ -11,7 +11,7 @@ void print_helper_msg()
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
} }
bool run_convnd_fwd_example(int argc, char* argv[]) bool run_convnd_example(int argc, char* argv[])
{ {
print_helper_msg(); print_helper_msg();
...@@ -63,23 +63,23 @@ bool run_convnd_fwd_example(int argc, char* argv[]) ...@@ -63,23 +63,23 @@ bool run_convnd_fwd_example(int argc, char* argv[])
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>( ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param); conv_param);
return run_grouped_conv_fwd<NDimSpatial, return run_grouped_conv<NDimSpatial,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
DeviceGroupedConvNDFwdActivInstance>(do_verification, DeviceGroupedConvNDActivInstance>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
conv_param, conv_param,
in_g_n_c_wis_desc, in_g_n_c_wis_desc,
wei_g_k_c_xs_desc, wei_g_k_c_xs_desc,
out_g_n_k_wos_desc, out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
}; };
if(conv_param.num_dim_spatial_ == 3) if(conv_param.num_dim_spatial_ == 3)
......
...@@ -2,48 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,48 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_activ_xdl) add_custom_target(example_convnd_activ_unary_xdl)
# Sigmoid # Sigmoid
add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_sigmoid_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_sigmoid_fp16)
# Tanh # Tanh
add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_tanh_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_tanh_fp16)
# Relu # Relu
add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_relu_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_relu_fp16)
# SoftRelu # SoftRelu
add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_softrelu_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_softrelu_fp16)
# Abs # Abs
add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_abs_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_abs_fp16)
# Pow # Pow
add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_pow_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_pow_fp16)
# Clipped Relu # Clipped Relu
add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_clippedrelu_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
# Leaky Relu # Leaky Relu
add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_leakyrelu_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
# Elu # Elu
add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16)
# ScaleAdd on A and B
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp16)
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp32)
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_bf16)
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_int8)
# ScaleAdd ScaleAdd Relu
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -102,16 +102,16 @@ template <ck::index_t NDimSpatial, ...@@ -102,16 +102,16 @@ template <ck::index_t NDimSpatial,
typename WeiElementOp, typename WeiElementOp,
typename OutElementOp, typename OutElementOp,
typename DeviceConvNDFwdInstance> typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification, bool run_grouped_conv(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParam& conv_param, const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc, const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc, const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc, const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op, const InElementOp& in_element_op,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
Tensor<InDataType> in(in_g_n_c_wis_desc); Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc); Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Elu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Power;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Relu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Sigmoid;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::SoftRelu;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_unary_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::TanH;
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "../run_convnd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
...@@ -44,16 +44,30 @@ ...@@ -44,16 +44,30 @@
#define CK_USE_WAVES_PER_EU 0 #define CK_USE_WAVES_PER_EU 0
#endif #endif
// define general macros for various architectures
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
#define __gfx101__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
#define __gfx103__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
// buffer resource // buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ defined(__gfx90a__) || defined(__gfx94__)
defined(__gfx942__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code #elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code #elif defined(__gfx11__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif #endif
...@@ -61,12 +75,12 @@ ...@@ -61,12 +75,12 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32 #define CK_USE_AMD_V_MAC_F32
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \ #elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code defined(__gfx94__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8 #define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) #elif defined(__gfx11__)
#define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11
...@@ -75,23 +89,22 @@ ...@@ -75,23 +89,22 @@
// MFMA instruction // MFMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_MFMA #define CK_USE_AMD_MFMA
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
defined(__gfx942__) // for GPU code
#define CK_USE_AMD_MFMA #define CK_USE_AMD_MFMA
#endif #endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) #if(defined(__gfx90a__) || defined(__gfx94__))
#define CK_USE_AMD_MFMA_BF16_1K_OP #define CK_USE_AMD_MFMA_BF16_1K_OP
#endif #endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx94__)
#define CK_USE_AMD_MFMA_GFX940 #define CK_USE_AMD_MFMA_GFX940
#endif #endif
// WMMA instruction // WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA #define CK_USE_AMD_WMMA
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code #elif defined(__gfx11__) // for GPU code
#define CK_USE_AMD_WMMA #define CK_USE_AMD_WMMA
#endif #endif
...@@ -107,15 +120,13 @@ ...@@ -107,15 +120,13 @@
// buffer atomic add: floating point // buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code #ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
defined(__gfx942__) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code #else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif #endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
defined(__gfx942__)) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else #else
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
......
...@@ -65,4 +65,23 @@ inline bool is_lds_direct_load_supported() ...@@ -65,4 +65,23 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
} }
inline bool is_navi1_supported()
{
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
ck::get_device_name() == "gfx1012";
}
inline bool is_navi2_supported()
{
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
}
inline bool is_navi3_supported()
{
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
}
} // namespace ck } // namespace ck
...@@ -37,7 +37,9 @@ template <index_t BlockSize, ...@@ -37,7 +37,9 @@ template <index_t BlockSize,
index_t NPerXDL, index_t NPerXDL,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack> index_t KPack,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = FloatB>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerXDL, NPerXDL, KPack, FloatB>{}; static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
...@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf); b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) { static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<FloatA, KPack> a_thread_vec; vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec; vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) = a_thread_buf a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}]; [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}]; [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
}); });
using mfma_input_type_a = using mfma_input_type_a =
typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b = using mfma_input_type_b =
typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
...@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops())); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA, ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k), decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>, Sequence<1, 1, 1, KPerThread>,
...@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1>; A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB, ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k), decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>, Sequence<1, 1, 1, KPerThread>,
...@@ -398,6 +401,8 @@ template <index_t BlockSize, ...@@ -398,6 +401,8 @@ template <index_t BlockSize,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack, index_t KPack,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = FloatB,
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS> index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
...@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack,
ComputeTypeA,
ComputeTypeB>
{ {
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatA, FloatA,
...@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>; KPack,
ComputeTypeA,
ComputeTypeB>;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using Base::a_block_desc_m0_m1_m2_k; using Base::a_block_desc_m0_m1_m2_k;
...@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
...@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatA, KPack> a_thread_vec; vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<FloatB, KPack> b_thread_vec; vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) = a_thread_vec.template AsType<ComputeTypeA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, k_ + i))>{}]; make_tuple(m0, 0, 0, k_ + i))>{}];
b_thread_vec.template AsType<FloatB>()(i) = b_thread_vec.template AsType<ComputeTypeB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, k_ + i))>{}]; make_tuple(n0, 0, 0, k_ + i))>{}];
}); });
using mfma_input_type_a = using mfma_input_type_a =
typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b = using mfma_input_type_b =
typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
...@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{})); make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA, ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k), decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>, Sequence<1, 1, 1, KPerInnerLoop>,
...@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1>; A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB, ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k), decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>, Sequence<1, 1, 1, KPerInnerLoop>,
...@@ -586,7 +595,9 @@ template <index_t BlockSize, ...@@ -586,7 +595,9 @@ template <index_t BlockSize,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack, index_t KPack,
LoopScheduler LoopSched> LoopScheduler LoopSched,
typename ComputeTypeA = FloatA,
typename ComputeTypeB = FloatB>
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
{ {
if constexpr(LoopSched == LoopScheduler::Default) if constexpr(LoopSched == LoopScheduler::Default)
...@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{}; KPack,
ComputeTypeA,
ComputeTypeB>{};
} }
else if constexpr(LoopSched == LoopScheduler::Interwave) else if constexpr(LoopSched == LoopScheduler::Interwave)
{ {
...@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{}; KPack,
ComputeTypeA,
ComputeTypeB>{};
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment