Commit a4fe62ed authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Merge remote-tracking branch 'upstream/develop' into ck_migraphx_integration

parents 08255e1b 3528a523
...@@ -39,6 +39,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta ...@@ -39,6 +39,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
...@@ -55,6 +69,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta ...@@ -55,6 +69,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
......
...@@ -102,12 +102,14 @@ function(add_instance_library INSTANCE_NAME) ...@@ -102,12 +102,14 @@ function(add_instance_library INSTANCE_NAME)
set(FMHA_FWD_FAST_EXP2 true) set(FMHA_FWD_FAST_EXP2 true)
endif() endif()
if(FMHA_FWD_FAST_EXP2) if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else() else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif() endif()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND FMHA_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(device_mha_instance PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
target_compile_options(device_mha_instance PRIVATE ${FMHA_COMPILE_OPTIONS})
endif() endif()
target_compile_features(${INSTANCE_NAME} PUBLIC) target_compile_features(${INSTANCE_NAME} PUBLIC)
......
...@@ -9,6 +9,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance ...@@ -9,6 +9,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
# large tensor # large tensor
# NHWGC, GKYXC, NHWGK # NHWGC, GKYXC, NHWGK
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
...@@ -19,6 +22,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance ...@@ -19,6 +22,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp
#mem #mem
# NHWGC, GKYXC, NHWGK # NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
...@@ -28,11 +34,20 @@ add_instance_library(device_grouped_conv2d_fwd_instance ...@@ -28,11 +34,20 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp
#comp #comp
# NHWGC, GKYXC, NHWGK # NHWGC, GKYXC, NHWGK
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp
#dl #dl
# GNHWC, GKYXC, GNHWK # GNHWC, GKYXC, GNHWK
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f16_comp_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f32_comp_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f16_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault,
Intrawave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault,
Intrawave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwd3x3>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
ConvFwd3x3>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -32,23 +32,33 @@ if(EXISTS ${FMHA_CPP_FOLDER}/blob_list.txt) ...@@ -32,23 +32,33 @@ if(EXISTS ${FMHA_CPP_FOLDER}/blob_list.txt)
file(REMOVE ${FMHA_CPP_FOLDER}/blob_list.txt) file(REMOVE ${FMHA_CPP_FOLDER}/blob_list.txt)
endif() endif()
set(FMHA_KNOWN_APIS "fwd,fwd_splitkv,fwd_appendkv,bwd")
# generate a list of kernels, but not actually emit files at config stage # generate a list of kernels, but not actually emit files at config stage
# Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
execute_process( execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
--list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt --list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt
--api ${FMHA_KNOWN_APIS}
--receipt 3
RESULT_VARIABLE ret RESULT_VARIABLE ret
) )
if(ret AND NOT ret EQUAL 0) if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile MHA FAILED to genrate a list of kernels via Python.") message( FATAL_ERROR "CK Tile MHA FAILED to genrate a list of kernels via Python.")
else() else()
file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_FWD_GEN_BLOBS) file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_GEN_BLOBS)
endif() endif()
# actually generate the kernel content now # actually generate the kernel content now
# Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
add_custom_command( add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS} OUTPUT ${FMHA_GEN_BLOBS}
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
--output_dir ${FMHA_CPP_FOLDER} --output_dir ${FMHA_CPP_FOLDER}
--api ${FMHA_KNOWN_APIS}
--receipt 3
COMMENT "Generating mha kernel (cpp) files now ..." COMMENT "Generating mha kernel (cpp) files now ..."
VERBATIM VERBATIM
) )
...@@ -57,12 +67,12 @@ add_custom_command( ...@@ -57,12 +67,12 @@ add_custom_command(
# have filename. Since, it was cauing the cmake # have filename. Since, it was cauing the cmake
# to throw "File name too long" # to throw "File name too long"
set(device_files) set(device_files)
foreach(filepath IN LISTS FMHA_FWD_GEN_BLOBS) foreach(filepath IN LISTS FMHA_GEN_BLOBS)
get_filename_component(filename ${filepath} NAME) get_filename_component(filename ${filepath} NAME)
# Append the filename to the device_files list # Append the filename to the device_files list
list(APPEND device_files ${filename}) list(APPEND device_files ${filename})
endforeach() endforeach()
add_custom_target(generate_cpp_files DEPENDS ${FMHA_FWD_GEN_BLOBS}) add_custom_target(generate_cpp_files DEPENDS ${FMHA_GEN_BLOBS})
add_instance_library(device_mha_instance ${device_files}) add_instance_library(device_mha_instance ${device_files})
......
...@@ -148,6 +148,11 @@ bool profile_grouped_conv_fwd_impl(int do_verification, ...@@ -148,6 +148,11 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
bool pass = true; bool pass = true;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
// workspace_sz will be equal to 0 for other layout than NGCHW
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init output to zero before profiling next kernel // re-init output to zero before profiling next kernel
......
...@@ -45,6 +45,8 @@ static void print_helper_msg() ...@@ -45,6 +45,8 @@ static void print_helper_msg()
"N, Ho, Wo, K]\n" "N, Ho, Wo, K]\n"
<< " 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, " << " 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, "
"Ho, Wo, G, K]\n" "Ho, Wo, G, K]\n"
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
"G, K, Ho, Wo]\n"
<< "arg4: verification (0: no, 1: yes)\n" << "arg4: verification (0: no, 1: yes)\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg6: print tensor value (0: no; 1: yes)\n"
......
...@@ -15,6 +15,7 @@ enum struct ConvLayout ...@@ -15,6 +15,7 @@ enum struct ConvLayout
{ {
GNHWC_GKYXC_GNHWK, // 0 GNHWC_GKYXC_GNHWK, // 0
NHWGC_GKYXC_NHWGK, // 1 NHWGC_GKYXC_NHWGK, // 1
NGCHW_GKYXC_NGKHW, // 2
}; };
enum struct ConvDataType enum struct ConvDataType
...@@ -54,6 +55,8 @@ static void print_helper_msg() ...@@ -54,6 +55,8 @@ static void print_helper_msg()
<< "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n" << "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
"G, K, Ho, Wo]\n"
<< "arg5: verification (0: no, 1: yes)\n" << "arg5: verification (0: no, 1: yes)\n"
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n" << "arg7: print tensor value (0: no; 1: yes)\n"
...@@ -111,6 +114,11 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) ...@@ -111,6 +114,11 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using GNHWK = ck::tensor_layout::convolution::GNHWK; using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK; using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
//
using NGCHW = ck::tensor_layout::convolution::NGCHW;
using NGKHW = ck::tensor_layout::convolution::NGKHW;
// //
using NWGC = ck::tensor_layout::convolution::NWGC; using NWGC = ck::tensor_layout::convolution::NWGC;
using NHWGC = ck::tensor_layout::convolution::NHWGC; using NHWGC = ck::tensor_layout::convolution::NHWGC;
...@@ -284,6 +292,17 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) ...@@ -284,6 +292,17 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
......
...@@ -28,6 +28,8 @@ def parse_layouts(args): ...@@ -28,6 +28,8 @@ def parse_layouts(args):
args.in_layout == "NCDHW": args.in_layout == "NCDHW":
if args.ck_profier_op == "grouped_conv_bwd_weight": if args.ck_profier_op == "grouped_conv_bwd_weight":
args.layout = 3 args.layout = 3
elif args.ck_profier_op == "grouped_conv_fwd":
args.layout = 2
else: else:
print('Not supported layout for this op') print('Not supported layout for this op')
exit(1) exit(1)
......
...@@ -62,7 +62,9 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>, ...@@ -62,7 +62,9 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
std::tuple<float, NHWGC, GKYXC, NHWGK>, std::tuple<float, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>, std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>, std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>; std::tuple<int8_t, NHWGC, GKYXC, NHWGK>,
std::tuple<float, NGCHW, GKYXC, NGKHW>,
std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>>;
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>, using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>, std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment