Unverified Commit ced5af16 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Extend support for contraction 6D (#1207)

* Extend support for contraction up to 5D

* Extend contraction bilinear instances

* Fix interface test

* Add 6d support, remove 3d,4d,5d

* Fixes

* Fix readme

* Make defualt dim for contraction instances
parent 366592b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]
// m/n/n/n are the fast changing dimension for A/B/D/E
using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance =
device_contraction_f64_mn_instance<F64,
F64,
F64,
F64,
F64_Tuple,
F64,
F64,
PassThrough,
PassThrough,
Bilinear,
6>;
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F64,
F64,
F64_Tuple,
F64,
PassThrough,
PassThrough,
Bilinear,
F64>>>& instances)
{
add_device_operation_instances(
instances,
device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
# ONLY XDL_KERNELS
set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
# FP32
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp)
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp)
# FP64
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp)
# FP16
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp)
# BF16
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp)
list(APPEND DIMS 2 6)
foreach(idx IN LISTS DIMS)
set(PREFIX ${idx}D/device_contraction_bilinear_m${idx}_n${idx}_k${idx})
# FP32
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp)
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp)
# FP64
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp
${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp)
# FP16
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp
${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp
${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp)
# BF16
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp
${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp
${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp
${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp)
endforeach()
add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kk
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kn
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mk
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mn
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_k
BF16,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_k
BF16,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_m
BF16,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_m
BF16,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kk
F16,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kn
F16,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mk
F16,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mn
F16,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance =
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// This (ifndef) is a hack to use customized behavior for buffer load rather than using default
// setting Don't use this hack unless absolutely necessary!
......@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance =
F32,
PassThrough,
PassThrough,
Scale>;
Scale,
2>;
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment