Commit 41b920e2 authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents 874a78f9 5d718e6b
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_multi_ab_common.hpp"
@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp  = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;

-using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
-                                                                                  AccDataType,
-                                                                                  ADataTypes,
-                                                                                  BDataTypes,
-                                                                                  InElementOp,
-                                                                                  WeiElementOp>;
+using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
+                                                                               AccDataType,
+                                                                               ADataTypes,
+                                                                               BDataTypes,
+                                                                               InElementOp,
+                                                                               WeiElementOp>;

-#include "../run_convnd_fwd_activ_example.inc"
+#include "../run_convnd_activ_example.inc"

-int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
+int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_multi_ab_common.hpp"
@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp  = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;

-using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
-                                                                                  AccDataType,
-                                                                                  ADataTypes,
-                                                                                  BDataTypes,
-                                                                                  InElementOp,
-                                                                                  WeiElementOp>;
+using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
+                                                                               AccDataType,
+                                                                               ADataTypes,
+                                                                               BDataTypes,
+                                                                               InElementOp,
+                                                                               WeiElementOp>;

-#include "../run_convnd_fwd_activ_example.inc"
+#include "../run_convnd_activ_example.inc"

-int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
+int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_multi_ab_common.hpp"
@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp  = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;

-using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
-                                                                                  AccDataType,
-                                                                                  ADataTypes,
-                                                                                  BDataTypes,
-                                                                                  InElementOp,
-                                                                                  WeiElementOp>;
+using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
+                                                                               AccDataType,
+                                                                               ADataTypes,
+                                                                               BDataTypes,
+                                                                               InElementOp,
+                                                                               WeiElementOp>;

-#include "../run_convnd_fwd_activ_example.inc"
+#include "../run_convnd_activ_example.inc"

-int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
+int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_multi_ab_common.hpp"
@@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple<DataType, DataType>;
using InElementOp  = ck::tensor_operation::element_wise::ScaleAdd;
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;

-using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
-                                                                                  AccDataType,
-                                                                                  ADataTypes,
-                                                                                  BDataTypes,
-                                                                                  InElementOp,
-                                                                                  WeiElementOp>;
+using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
+                                                                               AccDataType,
+                                                                               ADataTypes,
+                                                                               BDataTypes,
+                                                                               InElementOp,
+                                                                               WeiElementOp>;

-#include "../run_convnd_fwd_activ_example.inc"
+#include "../run_convnd_activ_example.inc"

-int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
+int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iostream>
@@ -100,16 +100,16 @@ template <ck::index_t NDimSpatial,
          typename WeiElementOp,
          typename OutElementOp,
          typename DeviceConvNDFwdInstance>
-bool run_grouped_conv_fwd(bool do_verification,
-                          int init_method,
-                          bool time_kernel,
-                          const ck::utils::conv::ConvParam& conv_param,
-                          const HostTensorDescriptor& in_g_n_c_wis_desc,
-                          const HostTensorDescriptor& wei_g_k_c_xs_desc,
-                          const HostTensorDescriptor& out_g_n_k_wos_desc,
-                          const InElementOp& in_element_op,
-                          const WeiElementOp& wei_element_op,
-                          const OutElementOp& out_element_op)
+bool run_grouped_conv(bool do_verification,
+                      int init_method,
+                      bool time_kernel,
+                      const ck::utils::conv::ConvParam& conv_param,
+                      const HostTensorDescriptor& in_g_n_c_wis_desc,
+                      const HostTensorDescriptor& wei_g_k_c_xs_desc,
+                      const HostTensorDescriptor& out_g_n_k_wos_desc,
+                      const InElementOp& in_element_op,
+                      const WeiElementOp& wei_element_op,
+                      const OutElementOp& out_element_op)
{
    constexpr ck::index_t NumAs = 2;
    constexpr ck::index_t NumBs = 2;
...
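The `NumAs = 2` / `NumBs = 2` constants above are what make these the multi-A/multi-B examples: each GEMM operand is assembled from two tensors by the `ScaleAdd` element-wise operator before the convolution proper. A minimal sketch of that operator's assumed shape (illustrative only, not CK's exact definition), so the kernel effectively convolves (scale * A0 + A1) with (scale * B0 + B1) without materializing the summed tensors:

// Hedged sketch, not CK's exact code. ScaleAddSketch is a hypothetical
// stand-in for ck::tensor_operation::element_wise::ScaleAdd.
struct ScaleAddSketch
{
    float scale_ = 1.f;

    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        // Fuse two inputs into one value on the fly, per element.
        y = static_cast<Y>(scale_ * static_cast<float>(x0) + static_cast<float>(x1));
    }
};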
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -11,7 +11,7 @@ void print_helper_msg()
       << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}

-bool run_convnd_fwd_example(int argc, char* argv[])
+bool run_convnd_example(int argc, char* argv[])
{
    print_helper_msg();
@@ -63,23 +63,23 @@ bool run_convnd_fwd_example(int argc, char* argv[])
        ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
            conv_param);

-        return run_grouped_conv_fwd<NDimSpatial,
-                                    InDataType,
-                                    WeiDataType,
-                                    OutDataType,
-                                    InElementOp,
-                                    WeiElementOp,
-                                    OutElementOp,
-                                    DeviceGroupedConvNDFwdActivInstance>(do_verification,
-                                                                         init_method,
-                                                                         time_kernel,
-                                                                         conv_param,
-                                                                         in_g_n_c_wis_desc,
-                                                                         wei_g_k_c_xs_desc,
-                                                                         out_g_n_k_wos_desc,
-                                                                         in_element_op,
-                                                                         wei_element_op,
-                                                                         out_element_op);
+        return run_grouped_conv<NDimSpatial,
+                                InDataType,
+                                WeiDataType,
+                                OutDataType,
+                                InElementOp,
+                                WeiElementOp,
+                                OutElementOp,
+                                DeviceGroupedConvNDActivInstance>(do_verification,
+                                                                  init_method,
+                                                                  time_kernel,
+                                                                  conv_param,
+                                                                  in_g_n_c_wis_desc,
+                                                                  wei_g_k_c_xs_desc,
+                                                                  out_g_n_k_wos_desc,
+                                                                  in_element_op,
+                                                                  wei_element_op,
+                                                                  out_element_op);
    };

    if(conv_param.num_dim_spatial_ == 3)
...
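For orientation, a hedged sketch of the argument contract this runner appears to follow (the first three flags mirror the `do_verification`, `init_method`, and `time_kernel` parameters above; the exact parsing lives in the elided lines and in the conv param parser behind `get_conv_param_parser_helper_msg`):

// Hedged sketch only; names and defaults are assumptions, not CK's exact code.
#include <cstdlib>

void parse_example_flags_sketch(int argc, char* argv[],
                                bool& do_verification, int& init_method, bool& time_kernel)
{
    // argv[1]: verify against a host reference, argv[2]: tensor init method,
    // argv[3]: time the kernel; remaining argv describe the conv problem.
    do_verification = argc > 1 && std::atoi(argv[1]) != 0;
    init_method     = argc > 2 ? std::atoi(argv[2]) : 1;
    time_kernel     = argc > 3 && std::atoi(argv[3]) != 0;
}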
@@ -2,48 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_custom_target(example_convnd_fwd_activ_xdl)
+        add_custom_target(example_convnd_activ_unary_xdl)
        # Sigmoid
        add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_sigmoid_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_sigmoid_fp16)
        # Tanh
        add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_tanh_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_tanh_fp16)
        # Relu
        add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_relu_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_relu_fp16)
        # SoftRelu
        add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_softrelu_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_softrelu_fp16)
        # Abs
        add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_abs_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_abs_fp16)
        # Pow
        add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_pow_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_pow_fp16)
        # Clipped Relu
        add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
        # Leaky Relu
        add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
        # Elu
        add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16)
-        # ScaleAdd on A and B
-        add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp16)
-        add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp32)
-        add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_bf16)
-        add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_int8)
-        # ScaleAdd ScaleAdd Relu
-        add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
-        add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
-        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
+        add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16)
        set(target 1)
    endif()
endforeach()
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
@@ -102,16 +102,16 @@ template <ck::index_t NDimSpatial,
          typename WeiElementOp,
          typename OutElementOp,
          typename DeviceConvNDFwdInstance>
-bool run_grouped_conv_fwd(bool do_verification,
-                          int init_method,
-                          bool time_kernel,
-                          const ck::utils::conv::ConvParam& conv_param,
-                          const HostTensorDescriptor& in_g_n_c_wis_desc,
-                          const HostTensorDescriptor& wei_g_k_c_xs_desc,
-                          const HostTensorDescriptor& out_g_n_k_wos_desc,
-                          const InElementOp& in_element_op,
-                          const WeiElementOp& wei_element_op,
-                          const OutElementOp& out_element_op)
+bool run_grouped_conv(bool do_verification,
+                      int init_method,
+                      bool time_kernel,
+                      const ck::utils::conv::ConvParam& conv_param,
+                      const HostTensorDescriptor& in_g_n_c_wis_desc,
+                      const HostTensorDescriptor& wei_g_k_c_xs_desc,
+                      const HostTensorDescriptor& out_g_n_k_wos_desc,
+                      const InElementOp& in_element_op,
+                      const WeiElementOp& wei_element_op,
+                      const OutElementOp& out_element_op)
{
    Tensor<InDataType> in(in_g_n_c_wis_desc);
    Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::Elu;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::Power;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::Relu;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::Sigmoid;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::SoftRelu;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }

// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_activ_unary_common.hpp"

using OutElementOp = ck::tensor_operation::element_wise::TanH;

using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;

#include "../run_convnd_activ_example.inc"

int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
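The nine files above differ only in the `OutElementOp` alias; each activation is a small functor applied to every output element as the convolution result is written out. A hedged sketch of the assumed unary interface, using an abs-like stand-in (`UnaryAbsSketch` is hypothetical, not CK's definition of `UnaryAbs`):

// Hedged sketch of the one-in/one-out element-op shape used by these examples.
struct UnaryAbsSketch
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        // Applied once per output element, fused into the epilogue.
        y = static_cast<Y>(x < static_cast<X>(0) ? -x : x);
    }
};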
@@ -44,16 +44,30 @@
#define CK_USE_WAVES_PER_EU 0
#endif

+// define general macros for various architectures
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#define __gfx94__
+#endif
+#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
+#define __gfx101__
+#endif
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
+#define __gfx103__
+#endif
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#define __gfx11__
+#endif
+
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__) // for GPU code
+    defined(__gfx90a__) || defined(__gfx94__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
-#elif defined(__gfx1030__) // for GPU code
+#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
+#elif defined(__gfx11__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
@@ -61,12 +75,12 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32
-#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code
+#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
+    defined(__gfx94__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
-#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+#elif defined(__gfx11__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
@@ -75,23 +89,22 @@
// MFMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_MFMA
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__) // for GPU code
+#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
#define CK_USE_AMD_MFMA
#endif

-#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(defined(__gfx90a__) || defined(__gfx94__))
#define CK_USE_AMD_MFMA_BF16_1K_OP
#endif

-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
#define CK_USE_AMD_MFMA_GFX940
#endif

// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
-#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
+#elif defined(__gfx11__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
@@ -107,15 +120,13 @@
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__) // for GPU code
+#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif

-#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
-    defined(__gfx942__)) // for GPU code
+#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
...
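The practical effect of the new `__gfx94__`/`__gfx101__`/`__gfx103__`/`__gfx11__` umbrella macros is that per-family code paths no longer enumerate every device. A small illustrative guard (the function itself is hypothetical; only the macro usage is from the diff above):

#if defined(__gfx94__) // one test now covers gfx940, gfx941, and gfx942
__device__ float gfx94_only_path_sketch(float a, float b, float c)
{
    return a * b + c; // placeholder for a gfx94x-specific instruction sequence
}
#endif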
@@ -65,4 +65,23 @@ inline bool is_lds_direct_load_supported()
           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}

+inline bool is_navi1_supported()
+{
+    return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
+           ck::get_device_name() == "gfx1012";
+}
+
+inline bool is_navi2_supported()
+{
+    return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
+           ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
+           ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
+}
+
+inline bool is_navi3_supported()
+{
+    return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
+           ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
+}
+
} // namespace ck
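These helpers give host code a family-level capability check that mirrors the device-side umbrella macros. A hedged usage sketch (the gating logic is illustrative, not CK's exact code):

// Hypothetical gate: run a WMMA-based example only on Navi3 (gfx110x) devices.
#include <iostream>

bool should_run_wmma_example_sketch()
{
    return ck::is_navi3_supported();
}

int run_or_skip_sketch()
{
    if(!should_run_wmma_example_sketch())
    {
        std::cout << "Skipping: this path needs a gfx110x GPU" << std::endl;
        return 0;
    }
    // ... launch the kernel ...
    return 0;
}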
@@ -37,7 +37,9 @@ template <index_t BlockSize,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
-          index_t KPack>
+          index_t KPack,
+          typename ComputeTypeA = FloatA,
+          typename ComputeTypeB = FloatB>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
    static constexpr auto I0 = Number<0>{};
@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerXDL, NPerXDL, KPack, FloatB>{};
+    static constexpr auto xdlops_gemm =
+        XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};

    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                    const BBlockBuffer& b_block_buf,
                    CThreadBuffer& c_thread_buf) const
    {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                    b_thread_buf);

                static_for<0, KPerThread, KPack>{}([&](auto k) {
-                    vector_type<FloatA, KPack> a_thread_vec;
-                    vector_type<FloatB, KPack> b_thread_vec;
+                    vector_type<ComputeTypeA, KPack> a_thread_vec;
+                    vector_type<ComputeTypeB, KPack> b_thread_vec;

                    static_for<0, KPack, 1>{}([&](auto i) {
-                        a_thread_vec.template AsType<FloatA>()(i) = a_thread_buf
+                        a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
-                        b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf
+                        b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                    });

                    using mfma_input_type_a =
-                        typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
+                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
                    using mfma_input_type_b =
-                        typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
+                        typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                                         FloatA,
+                                                         ComputeTypeA,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerThread>,
@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                         FloatB,
+                                                         ComputeTypeB,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerThread>,
@@ -398,6 +401,8 @@ template <index_t BlockSize,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
+          typename ComputeTypeA = FloatA,
+          typename ComputeTypeB = FloatB,
          index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                                 NPerXDL,
                                                                 MRepeat,
                                                                 NRepeat,
-                                                                 KPack>
+                                                                 KPack,
+                                                                 ComputeTypeA,
+                                                                 ComputeTypeB>
{
    using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                     FloatA,
@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                                     NPerXDL,
                                                                     MRepeat,
                                                                     NRepeat,
-                                                                     KPack>;
+                                                                     KPack,
+                                                                     ComputeTypeA,
+                                                                     ComputeTypeB>;

#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
    using Base::a_block_desc_m0_m1_m2_k;
@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                    const BBlockBuffer& b_block_buf,
                    CThreadBuffer& c_thread_buf) const
    {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
            static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        vector_type<FloatA, KPack> a_thread_vec;
-                        vector_type<FloatB, KPack> b_thread_vec;
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto i) {
-                            a_thread_vec.template AsType<FloatA>()(i) =
+                            a_thread_vec.template AsType<ComputeTypeA>()(i) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, 0, 0, k_ + i))>{}];
-                            b_thread_vec.template AsType<FloatB>()(i) =
+                            b_thread_vec.template AsType<ComputeTypeB>()(i) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, 0, 0, k_ + i))>{}];
                        });

                        using mfma_input_type_a =
-                            typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
                        using mfma_input_type_b =
-                            typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
+                            typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                                         FloatA,
+                                                         ComputeTypeA,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerInnerLoop>,
@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                         FloatB,
+                                                         ComputeTypeB,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerInnerLoop>,
@@ -586,7 +595,9 @@ template <index_t BlockSize,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          typename ComputeTypeA = FloatA,
+          typename ComputeTypeB = FloatB>
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
{
    if constexpr(LoopSched == LoopScheduler::Default)
@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
-                                                                KPack>{};
+                                                                KPack,
+                                                                ComputeTypeA,
+                                                                ComputeTypeB>{};
    }
    else if constexpr(LoopSched == LoopScheduler::Interwave)
    {
@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
-                                                                KPack>{};
+                                                                KPack,
+                                                                ComputeTypeA,
+                                                                ComputeTypeB>{};
    }
};
...
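The recurring change in this file is one pattern: trailing defaulted template parameters (`ComputeTypeA = FloatA`, `ComputeTypeB = FloatB`) keep every existing instantiation source-compatible while letting new callers stage the MFMA inputs in a different type than the one the tiles are stored in. A self-contained sketch of that pattern (all names below are illustrative, not CK's):

template <typename FloatA,
          typename FloatB,
          typename ComputeTypeA = FloatA,
          typename ComputeTypeB = FloatB>
struct BlockwiseGemmSketch
{
    // Tiles are stored as FloatA/FloatB; the per-thread buffers feeding the
    // multiply-accumulate are ComputeTypeA/ComputeTypeB, as in the diff above.
    static float MacOnce(FloatA a, FloatB b)
    {
        return static_cast<float>(static_cast<ComputeTypeA>(a)) *
               static_cast<float>(static_cast<ComputeTypeB>(b));
    }
};

// Existing instantiations compile unchanged:
using Legacy = BlockwiseGemmSketch<float, float>;
// New callers can compute in a narrower type than storage (illustrative):
using MixedCompute = BlockwiseGemmSketch<double, double, float, float>;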