Unverified Commit e1a5137e authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into transpose_5d

parents eb57178d 718065eb
...@@ -30,6 +30,7 @@ using ADataType = int8_t; ...@@ -30,6 +30,7 @@ using ADataType = int8_t;
using BDataType = int8_t; using BDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using CDataType = int32_t; using CDataType = int32_t;
using ComputeType = int8_t;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Col;
...@@ -43,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -43,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
// clang-format off // clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
// clang-format on // clang-format on
#include "run_splitK_gemm_example.inc" #include "run_splitK_gemm_example.inc"
......
...@@ -14,18 +14,22 @@ using ComputeDataType = float; ...@@ -14,18 +14,22 @@ using ComputeDataType = float;
struct YElementOp struct YElementOp
{ {
template <typename T> template <typename Y, typename X>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value || static_assert(ck::is_same<X, float>::value || ck::is_same<X, double>::value ||
ck::is_same<T, ck::half_t>::value, ck::is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T a; static_assert(ck::is_same<Y, float>::value || ck::is_same<Y, double>::value ||
ck::is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!");
X a;
ck::tensor_operation::element_wise::Sigmoid{}(a, x); ck::tensor_operation::element_wise::Sigmoid{}(a, x);
y = x * a; y = ck::type_convert<Y>(x * a);
}; };
}; };
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" #include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -60,7 +60,7 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -60,7 +60,7 @@ bool maxpool_bwd_test(bool do_verification,
1>; // InSrcOutDstVectorSize 1>; // InSrcOutDstVectorSize
using DeviceMaxPoolBwdInstance = ck::tensor_operation::device:: using DeviceMaxPoolBwdInstance = ck::tensor_operation::device::
DeviceIndexPoolBwdImpl<DOutDataType, IndexDataType, DInDataType, 4>; DeviceMaxPoolBwdImpl<DOutDataType, IndexDataType, DInDataType, 4>;
const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
const ck::index_t Xs = (X - 1) * window_dilation_w + 1; const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
...@@ -155,7 +155,8 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -155,7 +155,8 @@ bool maxpool_bwd_test(bool do_verification,
dout_n_c_ho_wo.mDesc.GetElementSpaceSize(), dout_n_c_ho_wo.mDesc.GetElementSpaceSize(),
din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(), din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(),
window_spatial_lengths, window_spatial_lengths,
window_strides); window_strides,
window_dilations);
if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get())) if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get()))
{ {
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_image_to_column)
add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
add_dependencies(example_image_to_column example_image_to_column_f32)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static inline constexpr ck::index_t NDimSpatial = 2;
using FP32 = float;
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
};
#define DefaultConvParams \
ck::utils::conv::ConvParam \
{ \
NDimSpatial, 1, 32, 1, 1, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
}
inline void print_help_msg()
{
std::cerr << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
inline bool parse_cmd_args(int argc,
char* argv[],
ExecutionConfig& config,
ck::utils::conv::ConvParam& conv_params)
{
constexpr int num_execution_config_args =
3; // arguments for do_verification, init_method, time_kernel
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
constexpr int threshold_to_catch_all_args =
threshold_to_catch_partial_args + num_conv_param_leading_args;
if(argc == 1)
{
// use default
config = ExecutionConfig{};
}
// catch only ExecutionConfig arguments
else if(argc == threshold_to_catch_partial_args)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
// catch both ExecutionConfig & ConvParam arguments
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_params = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args, argv);
}
else
{
print_help_msg();
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using InDataType = FP32;
using OutDataType = FP32;
using InLayout = ck::tensor_layout::convolution::GNHWC;
// clang-format off
using DeviceImgToColInstance = ck::tensor_operation::device::DeviceImageToColumnImpl
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
< NDimSpatial, InLayout, InDataType, OutDataType, 256, 128, 128, S<16, 16>, 1>;
// clang-format on
bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params)
{
const auto N = conv_params.N_;
const auto C = conv_params.C_;
const ck::index_t NDoHoWo =
N * ck::accumulate_n<ck::index_t>(
conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const ck::index_t CZYX =
C * ck::accumulate_n<ck::index_t>(
conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const auto in_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_params);
const auto out_desc = HostTensorDescriptor({NDoHoWo, CZYX});
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_g_n_c_wis_strides{};
std::array<ck::index_t, 2> output_m_k_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
copy(conv_params.input_spatial_lengths_, input_spatial_lengths);
copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths);
copy(conv_params.output_spatial_lengths_, output_spatial_lengths);
copy(in_desc.GetStrides(), input_g_n_c_wis_strides);
copy(out_desc.GetStrides(), output_m_k_strides);
copy(conv_params.conv_filter_strides_, conv_filter_strides);
copy(conv_params.conv_filter_dilations_, conv_filter_dilations);
copy(conv_params.input_left_pads_, input_left_pads);
copy(conv_params.input_right_pads_, input_right_pads);
Tensor<InDataType> in(in_desc);
Tensor<OutDataType> out_device(out_desc);
Tensor<OutDataType> out_host(out_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "out: " << out_device.mDesc << std::endl;
switch(config.init_method)
{
case 0: break;
case 1: in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
default: in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
// reset input to zero
out_device_buf.SetZero();
static_assert(std::is_default_constructible_v<DeviceImgToColInstance>);
// do conv
auto img2col = DeviceImgToColInstance{};
auto invoker = img2col.MakeInvoker();
auto argument = img2col.MakeArgument(in_device_buf.GetDeviceBuffer(),
out_device_buf.GetDeviceBuffer(),
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_g_n_c_wis_strides,
output_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
if(!img2col.IsSupportedArgument(argument))
{
std::cerr << "wrong! device_img2col with the specified compilation parameters does "
"not support this img2col problem"
<< std::endl;
return false;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t num_btype = NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType));
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
if(config.do_verification)
{
auto ref_image_to_column = ck::tensor_operation::host::
ReferenceImageToColumn<NDimSpatial, InLayout, InDataType, OutDataType>();
auto ref_invoker = ref_image_to_column.MakeInvoker();
auto ref_argument = ref_image_to_column.MakeArgument(in,
out_host,
conv_params.filter_spatial_lengths_,
conv_params.conv_filter_strides_,
conv_params.conv_filter_dilations_,
conv_params.input_left_pads_,
conv_params.input_right_pads_);
if(!ref_image_to_column.IsSupportedArgument(&ref_argument))
{
std::cerr << "wrong! ref_img2col with the specified compilation parameters does "
"not support this img2col problem"
<< std::endl;
return false;
}
ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_device.mData, out_host.mData);
}
return true;
}
int RunImageToColumnExample(int argc, char* argv[])
{
ExecutionConfig config;
ck::utils::conv::ConvParam conv_params = DefaultConvParams;
if(!parse_cmd_args(argc, argv, config, conv_params))
{
return EXIT_FAILURE;
}
if(conv_params.num_dim_spatial_ != NDimSpatial)
{
std::cerr << "unsupported # of spatial dimensions" << std::endl;
return EXIT_FAILURE;
}
return !RunImageToColumn(config, conv_params);
}
int main(int argc, char* argv[]) { return RunImageToColumnExample(argc, argv); }
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#pragma once #pragma once
#include "ck/config.h"
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
...@@ -200,9 +202,6 @@ ...@@ -200,9 +202,6 @@
// workaround: compiler issue on gfx908 // workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1 #define CK_WORKAROUND_SWDEV_388832 1
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
#define CK_WORKAROUND_GITHUB_ISSUE_824 1
// flag to enable (1) or disable (0) the debugging output in some kernels // flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0 #define DEBUG_LOG 0
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_CONFIG_H_IN
#define CK_CONFIG_H_IN
// clang-format off
//
// DataType supports in the current CK build
//
#ifndef DTYPES
#cmakedefine DTYPES "@DTYPES@"
#endif
// if DTYPES is not defined, enable all datatypes in headerfiles
#ifndef CK_ENABLE_ALL_DTYPES
#cmakedefine CK_ENABLE_ALL_DTYPES @CK_ENABLE_ALL_DTYPES@
#if defined(CK_ENABLE_ALL_DTYPES)
#ifndef CK_ENABLE_INT8
#define CK_ENABLE_INT8 "ON"
#endif
#ifndef CK_ENABLE_FP8
#define CK_ENABLE_FP8 "ON"
#endif
#ifndef CK_ENABLE_BF8
#define CK_ENABLE_BF8 "ON"
#endif
#ifndef CK_ENABLE_FP16
#define CK_ENABLE_FP16 "ON"
#endif
#ifndef CK_ENABLE_BF16
#define CK_ENABLE_BF16 "ON"
#endif
#ifndef CK_ENABLE_FP32
#define CK_ENABLE_FP32 "ON"
#endif
#ifndef CK_ENABLE_FP64
#define CK_ENABLE_FP64 "ON"
#endif
#endif
#endif
// if DTYPES are selectively enabled
#ifndef CK_ENABLE_INT8
#cmakedefine CK_ENABLE_INT8 @CK_ENABLE_INT8@
#endif
#ifndef CK_ENABLE_FP8
#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
#endif
#ifndef CK_ENABLE_BF8
#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
#endif
#ifndef CK_ENABLE_FP16
#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
#endif
#ifndef CK_ENABLE_BF16
#cmakedefine CK_ENABLE_BF16 @CK_ENABLE_BF16@
#endif
#ifndef CK_ENABLE_FP32
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
#endif
#ifndef CK_ENABLE_FP64
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
#endif
//
// Legacy DL kernel supports in the current CK build
// by default DL kernels are turned OFF
//
#ifndef CK_ENABLE_DL_KERNELS
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
...@@ -1042,12 +1042,12 @@ struct Merge_v2_magic_division ...@@ -1042,12 +1042,12 @@ struct Merge_v2_magic_division
using UpLengths = using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
using LowLengthsMagicDivisorMultipiler = decltype( using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{}, lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
Number<NDimLow>{})); Number<NDimLow>{}));
using LowLengthsMagicDivisorShift = decltype( using LowLengthsMagicDivisorShift = decltype(generate_tuple(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{}, lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
Number<NDimLow>{})); Number<NDimLow>{}));
LowLengths low_lengths_; LowLengths low_lengths_;
...@@ -1201,8 +1201,8 @@ struct Merge_v2r2_magic_division ...@@ -1201,8 +1201,8 @@ struct Merge_v2r2_magic_division
lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{}, lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
Number<NDimLow>{})); Number<NDimLow>{}));
using LowLengthsScanMagicDivisorShift = decltype( using LowLengthsScanMagicDivisorShift = decltype(generate_tuple(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{}, lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{},
Number<NDimLow>{})); Number<NDimLow>{}));
LowLengths low_lengths_; LowLengths low_lengths_;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp"
namespace ck {
/**
* DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit
* the data loaded from LDS to registers.
*
* The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C
* between them in such a way that threads from the same group need the same chunk of either
* matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the
* whole chunk from LDS to its own register space.
* Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size`
* of the chunk, and then share that data with other threads from the same lane group.
*
* Assumptions coming from the usage of DPP8:
* 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or
* `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` -
* - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either
* matrix A or B;
* - based on these values we determine which matrix to share.
* 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or
* `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) -
* - we have to make sure that the data to split is divisible by the number of
* threads in the group.
*
* General algorithm:
* C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
* A and B are visible to the whole block, C is distributed among each thread
* Assume:
* 1. A:
* 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
* 2. ABlockBuffer is DynamicBuffer
* 2. B:
* 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
* 2. BBlockBuffer is DynamicBuffer
* 3. C:
* 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
* 2. CThreadBuffer is StaticBuffer
* 4. BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
*/
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_BK0_BM_BK1,
typename BBlockDesc_BK0_BN_BK1,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
// BM10BN10ThreadClusterBM101, ...>
typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
// BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11,
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
{
using AIndex = MultiIndex<4>;
using BIndex = MultiIndex<4>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
static constexpr index_t BM11 = BM1PerThreadBM11;
static constexpr index_t BN11 = BN1PerThreadBN11;
static constexpr index_t BM1 = BM100 * BM101 * BM11;
static constexpr index_t BN1 = BN100 * BN101 * BN11;
static constexpr index_t BM0 = BM / BM1;
static constexpr index_t BN0 = BN / BN1;
// We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all
// threads in a lane group need the same chunk of B or A matrices and we can share them using
// DPP.
static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size);
static constexpr bool ShareB = BM101 == dpp8::lane_group_size ? true : false;
static constexpr bool ShareA = !ShareB;
// If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`,
// respectively) elements, so we split them between threads in lane group so each thread loads
// less data from LDS.
static constexpr index_t BM1PerThread =
ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11;
static constexpr index_t BN1PerThread =
ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11;
__host__ __device__ static constexpr auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_block_bk0_bm0_bm1_bk1;
}
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_block_desc_bk0_bn0_bn1_bk1;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_unmerge_transform(make_tuple(
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<BM0>{}),
make_unmerge_transform(
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_pass_through_transform(Number<BN0>{}),
make_unmerge_transform(
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
}
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<BM0, BM11, BN0, BN11>{};
}
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public:
__device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()},
b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()}
{
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
BM10BN10ThreadClusterBN10Xs::Size() == 2,
"wrong!");
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr auto adaptor0 =
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
make_pass_through_transform(BM0),
make_pass_through_transform(BM11),
make_pass_through_transform(BN0),
make_pass_through_transform(BN11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
}
__device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()
{
const auto offsetBM0 = c_thread_origin_data_idx_[I0];
// If sharing matrix A, we need a separate BM1 offset for each thread in lane group.
const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] +
dpp8::get_thread_idx_in_lane_group() * BM1PerThread
: c_thread_origin_data_idx_[I1];
return make_tuple(0, offsetBM0, offsetBM1, 0);
}
__device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()
{
const auto offsetBN0 = c_thread_origin_data_idx_[I2];
// If sharing matrix B, we need a separate BN1 offset for each thread in lane group.
const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] +
dpp8::get_thread_idx_in_lane_group() * BN1PerThread
: c_thread_origin_data_idx_[I3];
return make_tuple(0, offsetBN0, offsetBN1, 0);
}
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_contraction =
ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
CThreadDesc_BM0_BM11_BN0_BN11,
Sequence<BK0PerThread, BK1>,
Sequence<1, BM1PerThreadBM11>,
Sequence<1, BN1PerThreadBN11>,
ShareA>{};
static_for<0, BN0, 1>{}([&](auto bn0) {
static_for<0, BM0, 1>{}([&](auto bm0) {
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, bm0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, bn0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(bm0, I0, bn0, I0));
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, bm0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, bn0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(bm0, I0, bn0, I0));
});
});
});
}
private:
// A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThread>{}, Number<BK1>{}));
// B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThread>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
Sequence<BK0PerThread, 1, BM1PerThread, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
Sequence<BK0PerThread, 1, BN1PerThread, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
This diff is collapsed.
...@@ -221,7 +221,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -221,7 +221,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... // basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A // read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
...@@ -231,6 +235,45 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -231,6 +235,45 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
make_tuple(I0, m0, I0, I0, I0), make_tuple(I0, m0, I0, I0, I0),
a_thread_buf); a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0),
b_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else
{
static_for<0, KPerBlock / WmmaK, 1>{}(
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B // read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
...@@ -239,6 +282,15 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -239,6 +282,15 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
b_thread_desc_, b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0), make_tuple(I0, n0, I0, I0, I0),
b_thread_buf); b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
vector_type<FloatA, WmmaK> a_thread_vec; vector_type<FloatA, WmmaK> a_thread_vec;
vector_type<FloatB, WmmaK> b_thread_vec; vector_type<FloatB, WmmaK> b_thread_vec;
...@@ -265,6 +317,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -265,6 +317,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
}); });
}); });
} }
}
protected: protected:
// A[K0, M0, M1, M2, K1] // A[K0, M0, M1, M2, K1]
......
...@@ -4,27 +4,13 @@ ...@@ -4,27 +4,13 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck { namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return LoopScheduler::Interwave;
#else
return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1> template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment