Commit c5fb61a9 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Sync with new cmakefile changes

parent fdddc8f4
...@@ -449,11 +449,8 @@ struct DppGemm ...@@ -449,11 +449,8 @@ struct DppGemm
Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
{ {
static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value || static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
#if defined CK_ENABLE_FP8 is_same<BaseType, f8_t>::value || is_same<BaseType, int8_t>::value,
|| is_same<BaseType, f8_t>::value
#endif
|| is_same<BaseType, int8_t>::value,
"base BaseType must be double, float, half, bfloat16, and int8_t!"); "base BaseType must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) { static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
......
...@@ -415,12 +415,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -415,12 +415,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
#if defined CK_ENABLE_FP8
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
#endif
#if defined CK_ENABLE_BF8
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
#endif
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
...@@ -541,12 +536,12 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -541,12 +536,12 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
#if defined CK_ENABLE_FP8 // #if defined CK_ENABLE_FP8
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
#endif // #endif
#if defined CK_ENABLE_BF8 // #if defined CK_ENABLE_BF8
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
#endif // #endif
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
......
...@@ -18,12 +18,8 @@ namespace instance { ...@@ -18,12 +18,8 @@ namespace instance {
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
#if defined CK_ENABLE_BF8
using BF8 = ck::bf8_t; using BF8 = ck::bf8_t;
#endif
#if defined CK_ENABLE_FP8
using F8 = ck::f8_t; using F8 = ck::f8_t;
#endif
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -149,7 +145,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances = ...@@ -149,7 +145,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
>; >;
// f16_f16_f16_comp_f8 // f16_f16_f16_comp_f8
#if defined CK_ENABLE_BF8 && defined CK_ENABLE_FP8 // #if defined CK_ENABLE_BF8 && defined CK_ENABLE_FP8
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
...@@ -185,7 +181,7 @@ using device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances = ...@@ -185,7 +181,7 @@ using device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances =
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, BF8, F8> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, BF8, F8>
// clang-format on // clang-format on
>; >;
#endif // #endif
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
...@@ -50,4 +49,3 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_ ...@@ -50,4 +49,3 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -4,8 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT ...@@ -4,8 +4,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
if(DL_KERNELS) if(DL_KERNELS)
list(APPEND GROUPED_CONV3D_BWD_WEIGHT list(APPEND GROUPED_CONV3D_BWD_WEIGHT
...@@ -27,4 +26,9 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT ...@@ -27,4 +26,9 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp) wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
endif()
add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT}) add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT})
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...@@ -54,4 +53,3 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance ...@@ -54,4 +53,3 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment