Commit b6d3aa5d authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Merge branch 'gfx950' into andriy/lwpck-2430

parents a634647d b6f7cddd
...@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p ...@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances< device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances<
2, 2,
NHWGC, NHWGC,
GKYXC, GKYXC,
......
...@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p ...@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances< device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances<
2, 2,
NHWGC, NHWGC,
GKYXC, GKYXC,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances<
2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi ...@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances<
2, 2,
NHWGC, NHWGC,
GKYXC, GKYXC,
......
...@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi ...@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances<
2, 2,
NHWGC, NHWGC,
GKYXC, GKYXC,
......
...@@ -15,6 +15,10 @@ set(GROUPED_CONV3D_BWD_WEIGHT ...@@ -15,6 +15,10 @@ set(GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp
) )
if(DL_KERNELS) if(DL_KERNELS)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1 ...@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances< device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances<
3, 3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
......
...@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1 ...@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances< device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances<
3, 3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16 ...@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances<
3, 3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
......
...@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16 ...@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances<
3, 3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances<
3,
NGCDHW,
GKZYXC,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances<
3,
NGCDHW,
GKZYXC,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& ...@@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
{ {
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
auto number_of_accumulations = 1;
static_assert(
ReduceOpId == ck::ReduceTensorOp::AVG || ReduceOpId == ck::ReduceTensorOp::MAX,
"Warning: Unhandled ReduceOpId for setting up the number of accumulations!");
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
{
for(size_t i = 0; i < kernel_params.window_spatial_lengths.size(); ++i)
{
number_of_accumulations *= kernel_params.window_spatial_lengths.at(i);
}
}
auto absolute_error_threshold = 1.0; auto absolute_error_threshold = 1.0;
switch(in_params.init_method) switch(in_params.init_method)
{ {
...@@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& ...@@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
absolute_error_threshold = absolute_error_threshold =
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>( ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>(
absolute_error_threshold); absolute_error_threshold, number_of_accumulations);
auto relative_error_threshold = auto relative_error_threshold =
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>(); ck::utils::get_relative_threshold<ComputeDataType, OutDataType>(
number_of_accumulations);
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData, bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
out_n_c_do_ho_wo_host.mData, out_n_c_do_ho_wo_host.mData,
......
...@@ -85,7 +85,7 @@ int profile_layernorm(int argc, char* argv[]) ...@@ -85,7 +85,7 @@ int profile_layernorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Half) if(data_type == ck::DataTypeEnum::Half)
{ {
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, false, rank>( ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F16, false, rank>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else if(data_type == ck::DataTypeEnum::Float) else if(data_type == ck::DataTypeEnum::Float)
......
...@@ -133,12 +133,12 @@ def parse_logfile(logfile): ...@@ -133,12 +133,12 @@ def parse_logfile(logfile):
if 'Best Perf' in line: if 'Best Perf' in line:
lst=line.split() lst=line.split()
res.append(lst[4]) res.append(lst[4])
elif 'onnx_gemm' in logfile or 'mixed_gemm' in logfile: elif 'onnx_gemm' in logfile:
for line in open(logfile): for line in open(logfile):
if 'Best Perf' in line: if 'Best Perf' in line:
lst=line.split() lst=line.split()
res.append(lst[33]) res.append(lst[33])
elif 'splitK_gemm' in logfile: elif 'splitK_gemm' in logfile or 'mixed_gemm' in logfile:
for line in open(logfile): for line in open(logfile):
if 'Best Perf' in line: if 'Best Perf' in line:
lst=line.split() lst=line.split()
......
...@@ -22,6 +22,7 @@ python3 process_perf_data.py perf_gemm_bilinear.log ...@@ -22,6 +22,7 @@ python3 process_perf_data.py perf_gemm_bilinear.log
python3 process_perf_data.py perf_reduction.log python3 process_perf_data.py perf_reduction.log
python3 process_perf_data.py perf_splitK_gemm.log python3 process_perf_data.py perf_splitK_gemm.log
python3 process_perf_data.py perf_onnx_gemm.log python3 process_perf_data.py perf_onnx_gemm.log
python3 process_perf_data.py perf_mixed_gemm.log
file=./perf_fmha_fwd_gfx942.log file=./perf_fmha_fwd_gfx942.log
if [ -e "$file" ]; then if [ -e "$file" ]; then
......
...@@ -64,11 +64,11 @@ function(add_test_executable TEST_NAME) ...@@ -64,11 +64,11 @@ function(add_test_executable TEST_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(ARGN) if(ARGN)
if(ARGN MATCHES "_xdl") if(ARGN MATCHES "_xdl")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
elseif(ARGN MATCHES "_smfmac") elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
endif() endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
...@@ -141,11 +141,11 @@ function(add_gtest_executable TEST_NAME) ...@@ -141,11 +141,11 @@ function(add_gtest_executable TEST_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(ARGN) if(ARGN)
if(ARGN MATCHES "_xdl") if(ARGN MATCHES "_xdl")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
elseif(ARGN MATCHES "_smfmac") elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
endif() endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
......
...@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool kPadA = true; constexpr bool kPadM = true;
constexpr bool kPadB = true; constexpr bool kPadN = true;
constexpr bool kPadC = true; constexpr bool kPadK = true;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
...@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Lunching kernel with args:" std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl; << "}" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment