Commit 9f1b4276 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

resolved conflicts

parents 711857c4 c7010716
# ONLY XDL_KERNELS
add_instance_library(device_grouped_gemm_bias_instance
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
......
# ONLY XDL_KERNELS
add_instance_library(device_grouped_gemm_fastgelu_instance
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
......
# ONLY XDL_KERNELS
set(GROUPED_GEMM_FIXED_NK_INSTANCES)
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp
......
# ONLY XDL_AND_DL_KERNELS
set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp)
set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp)
set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -31,7 +31,9 @@ template <ck::index_t NDimSpatial,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
typename OutDataType,
typename AComputeType = InDataType,
typename BComputeType = AComputeType>
bool profile_grouped_conv_fwd_impl(int do_verification,
int init_method,
bool do_log,
......@@ -209,7 +211,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
OutElementOp,
AComputeType,
BComputeType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
......
......@@ -14,6 +14,8 @@
#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
......@@ -21,14 +23,6 @@
#include "ck/library/utility/literals.hpp"
namespace ck {
template <typename HostTensorA, typename HostTensorB, typename ElementOp>
void reference_permute_scale(HostTensorB& b_tensor,
const HostTensorA& a_tensor,
ElementOp tensor_op)
{
b_tensor.ForEach([&](auto& self, auto idx) { tensor_op(self(idx), a_tensor(idx)); });
}
namespace profiler {
template <typename ADataType, typename BDataType, index_t NumDim>
......@@ -46,7 +40,8 @@ bool profile_permute_scale_impl(int do_verification,
using ElementOp = ck::tensor_operation::element_wise::Scale;
float scale = 2.f;
Tensor<ADataType> a(lengths_vector, input_strides_vector);
std::array<Tensor<ADataType>, 1> as = {Tensor<ADataType>(lengths_vector, input_strides_vector)};
Tensor<ADataType>& a = as[0];
Tensor<BDataType> b(lengths_vector, output_strides_vector);
Tensor<BDataType> host_b(lengths_vector, output_strides_vector);
......@@ -83,7 +78,14 @@ bool profile_permute_scale_impl(int do_verification,
if(do_verification)
{
reference_permute_scale(host_b, a, ElementOp{scale});
using ReferenceElementwiseInstance =
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, ElementOp>;
auto ref_elementwise = ReferenceElementwiseInstance{};
auto ref_invoker = ref_elementwise.MakeInvoker();
auto ref_argument = ref_elementwise.MakeArgument(as, host_b, ElementOp{scale});
ref_invoker.Run(ref_argument);
}
auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
......
......@@ -2,19 +2,6 @@
set(PROFILER_SOURCES
profiler.cpp
profile_gemm.cpp
profile_gemm_splitk.cpp
profile_gemm_bias_add_reduce.cpp
profile_gemm_add_multiply.cpp
profile_gemm_multiply_add.cpp
profile_gemm_reduce.cpp
profile_batched_gemm.cpp
profile_batched_gemm_reduce.cpp
profile_conv_fwd.cpp
profile_conv_fwd_bias_relu.cpp
profile_conv_fwd_bias_relu_add.cpp
profile_conv_bwd_data.cpp
profile_grouped_conv_fwd.cpp
profile_grouped_conv_bwd_weight.cpp
profile_reduce.cpp
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
......@@ -29,16 +16,47 @@ set(PROFILER_SOURCES
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp
profile_transpose.cpp
profile_permute_scale.cpp
)
if(DL_KERNELS)
list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
if(GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp)
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp)
endif()
<<<<<<< HEAD
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
......@@ -55,11 +73,20 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
=======
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
>>>>>>> origin/develop
endif()
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
if(DL_KERNELS)
list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp)
endif()
set(PROFILER_EXECUTABLE ckProfiler)
......@@ -69,25 +96,6 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
......@@ -97,39 +105,65 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
if(GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
endif()
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
endif()
if(DL_KERNELS)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
endif()
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
......@@ -25,6 +25,7 @@ enum struct ConvDataType
INT8_INT8_INT8, // 3
F8_F8_F8, // 4
BF8_BF8_F8, // 5
F8_BF8_F8, // 6
};
#define OP_NAME "grouped_conv_fwd"
......@@ -40,7 +41,8 @@ static void print_helper_msg()
<< " 2: Input bf16, Weight bf16, Output bf16\n"
<< " 3: Input int8, Weight int8, Output int8\n"
<< " 4: Input fp8, Weight fp8, Output fp8\n"
<< " 5: Input bf8, Weight bf8, Output fp8)\n"
<< " 5: Input bf8, Weight bf8, Output fp8\n"
<< " 6: Input fp8, Weight bf8, Output fp8)\n"
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
<< "arg4: verification (0: no, 1: yes)\n"
......@@ -118,7 +120,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
auto out_layout,
auto in_type,
auto wei_type,
auto out_type) {
auto out_type,
auto a_compute_type,
auto b_compute_type) {
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
using InLayout = decltype(in_layout);
......@@ -129,13 +133,18 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using AComputeType = decltype(a_compute_type);
using BComputeType = decltype(b_compute_type);
bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType>(
OutDataType,
AComputeType,
BComputeType>(
do_verification, init_method, do_log, time_kernel, params);
return pass ? 0 : 1;
......@@ -146,57 +155,59 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{});
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{});
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{});
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{});
return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{});
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{});
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{});
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{});
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{});
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{});
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{});
return profile(
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{});
return profile(
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
}
// NHWGC_GKYXC_NHWGK
......@@ -204,65 +215,71 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{});
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{});
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{});
return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{});
return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{});
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{});
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{});
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{});
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{});
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{});
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F8_F8_F8)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{});
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, F8{}, F8{});
}
else if(data_type == ConvDataType::BF8_BF8_F8)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{});
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}, BF8{}, BF8{});
}
else if(data_type == ConvDataType::F8_BF8_F8)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{});
}
}
......
......@@ -17,9 +17,9 @@ enum struct GemmMatrixLayout
enum struct GemmDataType
{
F16_F16_F16, // 0
BF16_INT8_BF16,// 1
BF16_BF16_BF16 // 2
F16_F16_F16, // 0
BF16_INT8_BF16, // 1
BF16_BF16_BF16 // 2
};
#define OP_NAME "grouped_gemm_two_stage"
......
......@@ -46,7 +46,18 @@ function(add_test_executable TEST_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
message("removing xdl test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
message("removing wmma test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_executable(${TEST_NAME} ${ARGN})
......@@ -100,6 +111,18 @@ function(add_gtest_executable TEST_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
message("removing xdl test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
foreach(source IN LISTS ARGN)
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
message("removing wmma test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_executable(${TEST_NAME} ${ARGN})
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_gtest_executable(test_batched_gemm test_batched_gemm.cpp)
add_gtest_executable(test_batched_gemm test_batched_gemm_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance)
set(target 1)
endif()
endforeach()
\ No newline at end of file
endif()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(test_batched_gemm_gemm)
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
set(target 1)
endif()
endif()
endforeach()
\ No newline at end of file
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp)
if(result EQUAL 0)
add_custom_target(test_batched_gemm_gemm)
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
endif()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance)
set(target 1)
endif()
add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance)
endif()
endforeach()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(test_batched_gemm_softmax_gemm)
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
set(target 1)
endif()
endif()
endforeach()
\ No newline at end of file
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16_xdl.cpp)
if(result EQUAL 0)
add_custom_target(test_batched_gemm_softmax_gemm)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
endif()
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(test_batched_gemm_softmax_gemm_permute)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
endif()
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
endif()
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
endif()
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
endif()
set(target 1)
endif()
endforeach()
\ No newline at end of file
add_custom_target(test_batched_gemm_softmax_gemm_permute)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
endif()
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
endif()
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
endif()
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment