Commit 0277c89e authored by rocking

Merge branch 'develop' into gemm_softmax

parents 3e811ccf abf4bdb9
@@ -4,6 +4,8 @@
 #include <cstdlib>
 #include <stdlib.h>
 #include <half.hpp>
+#include "check_err.hpp"
 #include "config.hpp"
 #include "print.hpp"
 #include "device.hpp"
@@ -40,9 +42,9 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 // static constexpr auto GemmMNPadding =
-//     ck::tensor_operation::device::GemmSpecialization_t::MNPadding;
+//     ck::tensor_operation::device::GemmSpecialization::MNPadding;

 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
@@ -225,8 +227,7 @@ int main(int argc, char* argv[])
                                                       c_element_op);

             ref_invoker.Run(ref_argument);

-            check_error(c_host_tensors[i], c_device_tensors[i]);
+            ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
         }
     }
...
@@ -40,7 +40,7 @@ using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum;
 using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum;

 static constexpr auto GemmSpecialization =
-    ck::tensor_operation::device::GemmSpecialization_t::Default;
+    ck::tensor_operation::device::GemmSpecialization::Default;

 // clang-format off
 using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
...
-# Instructions for ```convnd_bwd_data_xdl``` Example
+# Instructions for ```example_convnd_bwd_data_xdl```

-## Docker script
-```bash
-docker run \
--it \
---rm \
---privileged \
---group-add sudo \
--w /root/workspace \
--v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
-rocm/tensorflow:rocm4.3.1-tf2.6-dev \
-/bin/bash
-```
-
-## Build ```convnd_bwd_data_xdl```
-```bash
-mkdir build && cd build
-```
-
-```bash
-# Need to specify target ID, example below is gfx908
-cmake \
--D BUILD_DEV=OFF \
--D CMAKE_BUILD_TYPE=Release \
--D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
--D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
--D CMAKE_PREFIX_PATH=/opt/rocm \
-..
-```
-
-```bash
-make -j convnd_bwd_data_xdl
-```
-
 ## Run ```example_convnd_bwd_data_xdl```
 ```bash
 #arg1: verification (0=no, 1=yes)
 #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
 #arg3: run kernel # of times (>1)
 #arg4: num_dim_spatial(1|2|3)
 #arg5 to ...: N, K, C, [Z,] [Y,] X, [Di,] [Hi,] Wi, [Sz,] [Sy,] Sx, [Dz,] [Dy,] Dx, [LeftPz,] [LeftPy,] LeftPx, [RightPz,] [RightPy,] RightPx
-./bin/convnd_bwd_data_xdl 0 1 5
+./bin/example_convnd_bwd_data_xdl 0 1 5
 ```

 Result
...
@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "config.hpp"
-#include "conv_utils.hpp"
+#include "conv_fwd_util.hpp"
 #include "print.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
@@ -29,7 +29,7 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
 using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
 using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

 static constexpr auto ConvBwdDefault =
-    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;
+    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;

 using DeviceConvBwdDataBasePtr =
     ck::tensor_operation::device::DeviceConvBwdDataPtr<InElementOp, WeiElementOp, OutElementOp>;
@@ -44,7 +44,7 @@ using DeviceConvNDBwdDataInstance = ck::tensor_operation::device::
         InElementOp,    // InElementwiseOperation
         WeiElementOp,   // WeiElementwiseOperation
         OutElementOp,   // OutElementwiseOperation
-        ConvBwdDefault, // ConvolutionBackwardDataSpecialization_t
+        ConvBwdDefault, // ConvolutionBackwardDataSpecialization
         NumDimSpatial,  // NumDimSpatial
         256,            // BlockSize
         128,            // MPerBlock
@@ -83,7 +83,7 @@ using ReferenceConvBwdDataInstance =
                                                 OutElementOp,
                                                 NumDimSpatial>;

-void PrintUseMsg()
+void print_use_msg()
 {
     std::cout << "arg1: verification (0=no, 1=yes)\n"
               << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
@@ -99,10 +99,10 @@ void PrintUseMsg()
               << " <right padding>, (ie RightPy, RightPx for 2D)\n"
               << std::endl;
 }

-ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
+ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
 {
     // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
-    ck::conv_util::ConvParams params;
+    ck::utils::conv::ConvParams params;
     int arg_idx = 5;

     params.num_dim_spatial = num_dim_spatial;
@@ -144,73 +144,7 @@ ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
     return params;
 }

-HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>& dims,
-                                                  int num_dim_spatial = 2)
-{
-    namespace tl = ck::tensor_layout::convolution;
-
-    switch(num_dim_spatial)
-    {
-    case 3: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{});
-    }
-    case 2: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{});
-    }
-    case 1: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{});
-    }
-    default: {
-        throw std::runtime_error("Unsupported number of spatial dimensions provided!");
-    }
-    }
-}
-
-HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_t>& dims,
-                                                    int num_dim_spatial = 2)
-{
-    namespace tl = ck::tensor_layout::convolution;
-
-    switch(num_dim_spatial)
-    {
-    case 3: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::KZYXC{});
-    }
-    case 2: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::KYXC{});
-    }
-    case 1: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::KXC{});
-    }
-    default: {
-        throw std::runtime_error("Unsupported number of spatial dimensions provided!");
-    }
-    }
-}
-
-HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t>& dims,
-                                                   int num_dim_spatial = 2)
-{
-    namespace tl = ck::tensor_layout::convolution;
-
-    switch(num_dim_spatial)
-    {
-    case 3: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWK{});
-    }
-    case 2: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWK{});
-    }
-    case 1: {
-        return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWK{});
-    }
-    default: {
-        throw std::runtime_error("Unsupported number of spatial dimensions provided!");
-    }
-    }
-}
-
-DeviceConvBwdDataBasePtr GetConvInstance(int num_dim_spatial)
+DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
 {
     switch(num_dim_spatial)
     {
@@ -236,7 +170,7 @@ int main(int argc, char* argv[])
     int nrepeat         = 5;
     int num_dim_spatial = 2;

-    ck::conv_util::ConvParams params;
+    ck::utils::conv::ConvParams params;
     params.C = 128;

     if(argc == 4)
@@ -256,15 +190,15 @@ int main(int argc, char* argv[])
         int cmdline_nargs = conv_args + 5;
         if(cmdline_nargs != argc)
         {
-            PrintUseMsg();
+            print_use_msg();
             exit(1);
         }

-        params = ParseConvParams(num_dim_spatial, argv);
+        params = parse_conv_params(num_dim_spatial, argv);
     }
     else if(argc != 1)
     {
-        PrintUseMsg();
+        print_use_msg();
         exit(1);
     }
@@ -288,11 +222,13 @@ int main(int argc, char* argv[])
                                               std::end(output_spatial_lengths));

     Tensor<InDataType> in_n_c_hi_wi_host_result(
-        GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
+        ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
     Tensor<InDataType> in_n_c_hi_wi_device_result(
-        GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
-    Tensor<WeiDataType> wei_k_c_y_x(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial));
-    Tensor<OutDataType> out_n_k_ho_wo(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial));
+        ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
+    Tensor<WeiDataType> wei_k_c_y_x(
+        ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
+    Tensor<OutDataType> out_n_k_ho_wo(
+        ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial));

     std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
     std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
@@ -318,11 +254,10 @@ int main(int argc, char* argv[])
     out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
     wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());

     // reset input to zero
-    in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
-    in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
+    in_device_buf.SetZero();

     // do GEMM
-    auto conv    = GetConvInstance(num_dim_spatial);
+    auto conv    = get_conv_instance(num_dim_spatial);
     auto invoker = conv->MakeInvokerPointer();
     auto argument =
         conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
@@ -351,15 +286,15 @@ int main(int argc, char* argv[])
     float ave_time = invoker->Run(argument.get(), nrepeat);

-    std::size_t flop = ck::conv_util::GetFlops(
+    std::size_t flop = ck::utils::conv::get_flops(
         params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
-    std::size_t num_btype =
-        ck::conv_util::GetBtype<InDataType, WeiDataType, OutDataType>(params.N,
-                                                                      params.C,
-                                                                      params.K,
-                                                                      params.input_spatial_lengths,
-                                                                      params.filter_spatial_lengths,
-                                                                      output_spatial_lengths);
+    std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
+        params.N,
+        params.C,
+        params.K,
+        params.input_spatial_lengths,
+        params.filter_spatial_lengths,
+        output_spatial_lengths);

     float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
     float gb_per_sec = num_btype / 1.E6 / ave_time;
...
@@ -40,7 +40,7 @@ using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum;
 using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum;

 static constexpr auto GemmSpecialization =
-    ck::tensor_operation::device::GemmSpecialization_t::Default;
+    ck::tensor_operation::device::GemmSpecialization::Default;

 // clang-format off
 using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle
...
@@ -13,6 +13,7 @@ include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor
     ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu
     ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu
+    ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility
     ${PROJECT_SOURCE_DIR}/external/include/half
 )
@@ -29,13 +30,11 @@ add_subdirectory(01_gemm)
 add_subdirectory(02_gemm_alpha_beta)
 add_subdirectory(03_gemm_bias_relu)
 add_subdirectory(04_gemm_bias_relu_add)
-add_subdirectory(05_conv2d_fwd)
 add_subdirectory(06_conv2d_fwd_bias_relu)
 add_subdirectory(07_conv2d_fwd_bias_relu_add)
-add_subdirectory(08_conv3d_fwd)
 add_subdirectory(09_convnd_fwd)
 add_subdirectory(10_conv2d_bwd_data)
-add_subdirectory(11_conv2d_bwd_wgt)
+add_subdirectory(11_conv2d_bwd_weight)
 add_subdirectory(12_reduce)
 add_subdirectory(13_pool2d_fwd)
 add_subdirectory(14_gemm_xdl_requant_relu_requant)
...
@@ -6,15 +6,9 @@
 #include "hip/hip_fp16.h"
 #endif

-// "Constant" address space for kernel parameter
-#define CONSTANT __attribute__((address_space(4)))
+// constant address space for kernel parameter
+// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
+#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
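For context: address space 4 is the AMDGPU constant address space described in the LLVM page linked above, and kernel arguments are typically passed through it. A minimal sketch of how such an annotated pointer is usually cast back to a generic pointer before use (the function name is illustrative, not necessarily CK's API):

```cpp
// Sketch, assuming HIP/clang on AMDGPU: a C-style cast lowers to an
// addrspacecast from the constant address space (4) to the generic one (0).
template <typename T>
__host__ __device__ T* cast_to_generic(T CK_CONSTANT_ADDRESS_SPACE* p)
{
    return (T*)p; // only a C-style cast compiles across address spaces here
}
```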
-// GPU target
-// should enable one and only one GPU target
-#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
-      defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
-#error Need to define (only) one GPU target
-#endif

 // launch bounds
 #define CK_USE_LAUNCH_BOUNDS 1
@@ -24,155 +18,134 @@
 #define CK_MIN_BLOCK_PER_CU 2
 #endif
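For context, these two macros are the usual inputs to HIP's `__launch_bounds__` attribute, which caps per-thread register use so that a minimum number of blocks can stay resident per compute unit. A sketch of the typical consumption; `CK_MAX_THREAD_PER_BLOCK` and `CK_KERNEL_BOUNDS` are assumed names, not part of this hunk:

```cpp
#if CK_USE_LAUNCH_BOUNDS
#define CK_KERNEL_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#else
#define CK_KERNEL_BOUNDS
#endif

// the attribute lets the compiler budget registers for at least
// CK_MIN_BLOCK_PER_CU resident blocks per compute unit
__global__ CK_KERNEL_BOUNDS void example_kernel(float* out) { out[0] = 0.f; }
```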
-// GPU-specific parameters
-#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
-    defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
-// buffer resource
+// check GPU target
+#ifdef __HIP_DEVICE_COMPILE__
+#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
+      defined(__gfx90a__) || defined(__gfx1030__))
+#error Not supported target
+#endif
+#endif

+// buffer resource, wave size
+#ifndef __HIP_DEVICE_COMPILE__ // for host code
+#define CK_BUFFER_RESOURCE_3RD_DWORD -1
+#define CK_GPU_WAVE_SIZE -1
+#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
+    defined(__gfx90a__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
-// wave size
 #define CK_GPU_WAVE_SIZE 64
-#elif defined(CK_AMD_GPU_GFX1030)
+#elif defined(__gfx1030__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
 #define CK_GPU_WAVE_SIZE 32
 #endif
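`CK_BUFFER_RESOURCE_3RD_DWORD` is the last 32-bit word of the 128-bit buffer resource descriptor ("SRD") that AMD `buffer_load`/`buffer_store` instructions consume; it carries format/config bits and is the only word that differs between the gfx9 (0x00020000) and gfx1030 (0x31014000) targets. A host-side sketch of the descriptor layout, with field names taken from the AMD ISA documentation rather than from CK:

```cpp
#include <cstdint>

struct BufferResource // 128-bit SRD, four 32-bit dwords
{
    uint64_t base_address; // dword0-1: buffer base address (plus stride bits)
    uint32_t num_records;  // dword2 : buffer size in bytes for raw buffers
    uint32_t config;       // dword3 : CK_BUFFER_RESOURCE_3RD_DWORD
};

BufferResource make_buffer_resource(const void* p, uint32_t size_bytes)
{
    return {reinterpret_cast<uint64_t>(p), size_bytes, 0x00020000u}; // gfx9 value
}
```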
 // FMA instruction
-#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900)
+#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
+#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
 #define CK_USE_AMD_V_MAC_F32
-#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \
-    defined(CK_AMD_GPU_GFX1030)
+#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx1030__) // for GPU code
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
 #define CK_USE_AMD_V_DOT4_I32_I8
 #endif
-// multi index
-#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
-
-// AMD inline asm
-#ifndef CK_USE_AMD_INLINE_ASM
-#define CK_USE_AMD_INLINE_ASM 1
-#endif
-
-// AMD inner product (DLOP)
-#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
-#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
-#endif
+// MFMA instruction
+#ifndef __HIP_DEVICE_COMPILE__ // for host code
+#define CK_USE_AMD_MFMA
+#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
+#define CK_USE_AMD_MFMA
+#endif
+
+#if defined(__gfx90a__)
+#define CK_USE_AMD_MFMA_BF16_1K_OP
+#endif
-// AMD buffer_load
-#ifndef CK_USE_AMD_BUFFER_LOAD
+// buffer load
 #define CK_USE_AMD_BUFFER_LOAD 1
-#endif

-// AMD buffer_store
-#ifndef CK_USE_AMD_BUFFER_STORE
+// buffer store
 #define CK_USE_AMD_BUFFER_STORE 1
-#endif

-// AMD buffer_atomic_add
-#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD
-#define CK_USE_AMD_BUFFER_ATOMIC_ADD 1
-#endif
+// buffer atomic add: integer
+#define CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1

-// AMD XDLOPS
-#ifndef CK_USE_AMD_XDLOPS
-#define CK_USE_AMD_XDLOPS 0
-#endif
+// buffer atomic add: floating point
+#ifndef __HIP_DEVICE_COMPILE__ // for host code
+#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
+#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
+#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
+#else // for GPU code
+#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
+#endif
+
+// inline asm
+#define CK_USE_AMD_INLINE_ASM 1
+
+// inner product (DLOP)
+#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
 // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
-#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
-#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
-#endif
+#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
-// experimental implementation for buffer load/store/atomic
-#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
-#endif
-
-#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
-#endif
-
-#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
+// experimental feature: multi index implemented as array
+#define CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
+
+// experimental feature: static tensor descriptor
+#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
+
+// experimental feature: buffer load/store/atomic-add OOB trick
+#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
+#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
-#endif
-// experimental implementation for in-register sub-dword transpose
-#ifndef CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
+// experimental feature: in-register sub-dword transpose
 #define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1
-#endif
-
-#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0

-// merge transformation use magic number division
-#ifndef CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
+// experimental feature: merge transformation use magic number division
 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
-#endif
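The magic-number division that the merge transform relies on replaces an integer divide (slow on GPUs) with a multiply and shifts precomputed from the divisor. A self-contained sketch of the classic round-down variant (Granlund and Montgomery); CK's own implementation may differ in details such as the supported dividend range:

```cpp
#include <cstdint>

struct MagicDiv
{
    uint32_t multiplier;
    uint32_t shift;
};

// precompute once per divisor d (1 <= d <= 2^31)
MagicDiv make_magic(uint32_t d)
{
    uint32_t s = 0;
    while((1ull << s) < d)
        ++s; // s = ceil(log2(d))
    uint32_t m = uint32_t(((uint64_t(1) << 32) * ((uint64_t(1) << s) - d)) / d + 1);
    return {m, s};
}

// q = n / d without a divide instruction
uint32_t magic_div(uint32_t n, MagicDiv md)
{
    uint64_t hi = (uint64_t(n) * md.multiplier) >> 32; // mulhi(n, multiplier)
    return uint32_t((hi + n) >> md.shift);
}
```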
-// use __builtin_memcpy instead of pointer cast to access a vector from pointer of scalar
-#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
+// experimental feature: use __builtin_memcpy instead of pointer cast to access a vector from
+// pointer of scalar
 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
-#endif

-// use __builtin_memcpy instead of union to do bit_cast
-#ifndef CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
+// experimental feature: use __builtin_memcpy instead of union to do bit_cast
 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
-#endif
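A short illustration of why the memcpy form is preferred: `__builtin_memcpy` with a compile-time size compiles down to a plain register move while staying well-defined under C++ object-lifetime and aliasing rules, unlike the union trick or `reinterpret_cast`. Sketch only; CK's actual bit_cast signature may differ:

```cpp
#include <cstdint>

template <typename Dst, typename Src>
Dst bit_cast(const Src& src)
{
    static_assert(sizeof(Dst) == sizeof(Src), "bit_cast requires equal sizes");
    Dst dst;
    __builtin_memcpy(&dst, &src, sizeof(Dst));
    return dst;
}

// example: inspect the IEEE-754 bit pattern of 1.0f
// uint32_t bits = bit_cast<uint32_t>(1.0f); // 0x3f800000
```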
 // hack: has underlying assumptions that need to be satisfied, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
 // thread-invariant, otherwise it's a bug
 // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
-#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
 #define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
-#endif
-// workaround for compiler crash when compiling recursive lambda
-#ifndef CK_WORKAROUND_SWDEV_275126
+// workaround: compiler crash when compiling recursive lambda
 #define CK_WORKAROUND_SWDEV_275126 1
-#endif

-// workaround for compiler crash when using buffer load/store for i8
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
+// workaround: compiler crash when using buffer load/store for i8
 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
-#endif

-// workaround for compiler generating inefficient ds_write instructions
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+// workaround: compiler generating inefficient ds_write instructions
 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
-#endif

-// workaround for register spill due to compiler issue, when casting type between fp32 and fp16
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
-#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE 1
-#endif
-
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE
-#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1
-#endif
-
-// workaround for verification failure, due to compiler regression, for conv bwd-data fp16 using some
-// tuning parameter
-#ifndef CK_WORKAROUND_SWDEV_325164
+// workaround: verification failure, due to compiler regression, for conv bwd-data fp16 using some
+// tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 1
-#endif
 // workaround for verification failure ConvNd forward
 // https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135
-#ifndef CK_WORKAROUND_GITHUB_135
 #define CK_WORKAROUND_GITHUB_135 1
-#endif
 namespace ck {

-enum struct InMemoryDataOperationEnum_t
+enum struct InMemoryDataOperationEnum
 {
     Set,
     AtomicAdd,
     Add
 };
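This enum is threaded through the blockwise transfer templates further down as `DstInMemOp` and selects how a computed value is committed to the destination. A hypothetical dispatch, purely illustrative of the semantics (not CK's code):

```cpp
template <InMemoryDataOperationEnum Op, typename T>
__device__ void commit(T* dst, T value)
{
    if constexpr(Op == InMemoryDataOperationEnum::Set)
        *dst = value; // plain store
    else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
        atomicAdd(dst, value); // hardware read-modify-write
    else
        *dst += value; // Add: non-atomic accumulate
}
```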
-enum struct ActivTypeEnum_t
+// TODO: no longer needed, remove this
+enum struct ActivTypeEnum
 {
     None,
     LeakyRelu,
...
@@ -4,7 +4,7 @@
 namespace ck {

 // StaticTensor for Scalar
-template <AddressSpaceEnum_t AddressSpace,
+template <AddressSpaceEnum AddressSpace,
           typename T,
           typename TensorDesc,
           bool InvalidElementUseNumericalZeroValue,
@@ -80,7 +80,7 @@ struct StaticTensor
 };

 // StaticTensor for vector
-template <AddressSpaceEnum_t AddressSpace,
+template <AddressSpaceEnum AddressSpace,
           typename S,
           index_t ScalarPerVector,
           typename TensorDesc,
@@ -245,7 +245,7 @@ struct StaticTensorTupleOfVectorBuffer
     S ignored_element_scalar_;
 };

-template <AddressSpaceEnum_t AddressSpace,
+template <AddressSpaceEnum AddressSpace,
           typename T,
           typename TensorDesc,
           typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
@@ -255,7 +255,7 @@ __host__ __device__ constexpr auto make_static_tensor(TensorDesc)
 }

 template <
-    AddressSpaceEnum_t AddressSpace,
+    AddressSpaceEnum AddressSpace,
     typename T,
     typename TensorDesc,
     typename X,
...
@@ -207,9 +207,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
                           CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
                       "wrong");

-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
             a_k_m0_m1_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
             b_k_n0_n1_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
...
@@ -220,9 +220,9 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
                           CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
                       "wrong");

-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
             a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
             b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());

         constexpr auto threadwise_contraction =
...
@@ -119,7 +119,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
         constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};

         // thread A buffer for GEMM
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
+        StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
             a_thread_buf;

         constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
...
@@ -42,7 +42,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

-    StaticBufferTupleOfVector<AddressSpaceEnum_t::Vgpr,
+    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                               FloatAcc,
                               MRepeat * NRepeat,
                               xdlops_gemm.GetRegSizePerXdlops(),
@@ -250,9 +250,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                  const BBlockBuffer& b_block_buf,
                  CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             b_thread_desc_.GetElementSpaceSize());

         static_for<0, MRepeat, 1>{}([&](auto m0) {
...
@@ -16,7 +16,7 @@ namespace ck {
 template <index_t BlockSize,
           typename SrcElementwiseOperation,
           typename DstElementwiseOperation,
-          InMemoryDataOperationEnum_t DstInMemOp,
+          InMemoryDataOperationEnum DstInMemOp,
           typename BlockSliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
...
@@ -14,7 +14,7 @@ namespace ck {
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
-          InMemoryDataOperationEnum_t DstInMemOp,
+          InMemoryDataOperationEnum DstInMemOp,
           typename BlockSliceLengths,
           typename ThreadSliceLengths,
           typename ThreadClusterLengths,
...
@@ -15,7 +15,7 @@ namespace ck {
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
           typename ElementwiseOperation,
-          InMemoryDataOperationEnum_t DstInMemOp,
+          InMemoryDataOperationEnum DstInMemOp,
           typename BlockSliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
...
@@ -15,7 +15,7 @@ namespace ck {
 // 3. Run() does not construct new tensor coordinate
 template <index_t BlockSize,
           typename ElementwiseOperation,
-          InMemoryDataOperationEnum_t DstInMemOp,
+          InMemoryDataOperationEnum DstInMemOp,
           typename BlockSliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
...
@@ -15,7 +15,7 @@ namespace ck {
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
           typename ElementwiseOperation,
-          InMemoryDataOperationEnum_t DstInMemOp,
+          InMemoryDataOperationEnum DstInMemOp,
           typename BlockSliceLengths,
           typename ThreadClusterLengths,
           typename ThreadClusterArrangeOrder,
...
@@ -26,16 +26,20 @@
 #ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
 #define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP

+#include "data_type.hpp"
 #include "reduction_common.hpp"
+#include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
 #include "cluster_descriptor.hpp"

 namespace ck {

+// clang-format off
+// Assume:
+//  1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
+//  2) work_buffer has AccDataType elements, and space size is no less than BlockSize
+//  3) in_out_value is the input data in vgpr from each thread
+//  4) in_out_value is the over-written reduced output in vgpr for each thread
+// clang-format on
 template <typename AccDataType,
           index_t BlockSize,
           typename ThreadClusterLengths_M_K,
@@ -61,8 +65,11 @@ struct PartitionedBlockwiseReduction
     using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;

     template <typename BufferType>
-    __device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData)
+    __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
     {
+        static_assert(is_same<typename BufferType::type, AccDataType>{},
+                      "Buffer data type should be consistent as AccDataType!");
+
         constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

         const auto thread_cluster_idx =
@@ -71,6 +78,10 @@ struct PartitionedBlockwiseReduction
         const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
         const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

+        work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
+
+        __syncthreads();
+
         static_for<0, cluster_len_shift, 1>{}([&](auto I) {
             constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
@@ -80,10 +91,10 @@ struct PartitionedBlockwiseReduction
                 index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                      make_tuple(0, indOffset));

-                AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
-                AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
+                AccDataType opData1 = work_buffer[offset1];
+                AccDataType opData2 = work_buffer[offset2];
                 Accumulation::Calculate(opData1, opData2);
-                block_buffer(offset1) = type_convert<AccDataType>(opData1);
+                work_buffer(offset1) = opData1;
             }

             __syncthreads();
@@ -91,10 +102,17 @@ struct PartitionedBlockwiseReduction
         index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

-        accuData = type_convert<AccDataType>(block_buffer[offset]);
+        in_out_value = work_buffer[offset];
     };
 };
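The loop above is a binary-tree reduction across the K dimension of the thread cluster: at step I each active thread combines with the partner `1 << (cluster_len_shift - 1 - I)` positions away, halving the active range until row 0 holds the result. A plain-C++ illustration of the same access pattern (on the GPU the inner loop is the threads running in parallel, separated by `__syncthreads()`):

```cpp
#include <vector>

float block_reduce_sum(std::vector<float>& buf) // buf.size() must be a power of two
{
    for(int stride = int(buf.size()) / 2; stride > 0; stride /= 2) // 1 << (shift - 1 - I)
        for(int k = 0; k < stride; ++k) // each iteration = one thread
            buf[k] += buf[k + stride];  // Accumulation::Calculate for a sum reduction
    return buf[0]; // what the offset of (thread_m_cluster_id, 0) reads back
}
```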
+// clang-format off
+// Assume:
+//  1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
+//  2) work_val_buffer/work_idx_buffer has AccDataType/IndexDataType elements, and space size is no less than BlockSize
+//  3) in_out_value/in_out_index is the input data in vgpr from each thread
+//  4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread
+// clang-format on
 template <typename AccDataType,
           typename IndexDataType,
           index_t BlockSize,
@@ -123,11 +141,16 @@ struct PartitionedBlockwiseReductionWithIndex
     // This interface accumulates on both data values and indices
     template <typename BufferType, typename IdxBufferType>
-    __device__ static void Reduce(BufferType& block_val_buffer,
-                                  IdxBufferType& block_idx_buffer,
-                                  AccDataType& accuData,
-                                  IndexDataType& accuIndex)
+    __device__ static void Reduce(BufferType& work_val_buffer,
+                                  IdxBufferType& work_idx_buffer,
+                                  AccDataType& in_out_value,
+                                  IndexDataType& in_out_index)
     {
+        static_assert(is_same<typename BufferType::type, AccDataType>{},
+                      "Buffer data type should be consistent as AccDataType!");
+        static_assert(is_same<typename IdxBufferType::type, IndexDataType>{},
+                      "Buffer data type should be consistent as IndexDataType!");
+
         constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

         const auto thread_cluster_idx =
@@ -136,6 +159,11 @@ struct PartitionedBlockwiseReductionWithIndex
         const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
         const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

+        work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
+        work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;
+
+        __syncthreads();
+
         static_for<0, cluster_len_shift, 1>{}([&](auto I) {
             constexpr index_t indOffset = 1 << I();
@@ -145,14 +173,14 @@ struct PartitionedBlockwiseReductionWithIndex
                 index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                      make_tuple(0, indOffset));

-                AccDataType opData1      = type_convert<AccDataType>(block_val_buffer[offset1]);
-                AccDataType opData2      = type_convert<AccDataType>(block_val_buffer[offset2]);
-                IndexDataType currIndex1 = block_idx_buffer[offset1];
-                IndexDataType currIndex2 = block_idx_buffer[offset2];
+                AccDataType opData1      = work_val_buffer[offset1];
+                AccDataType opData2      = work_val_buffer[offset2];
+                IndexDataType currIndex1 = work_idx_buffer[offset1];
+                IndexDataType currIndex2 = work_idx_buffer[offset2];

                 Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
-                block_val_buffer(offset1) = type_convert<AccDataType>(opData1);
-                block_idx_buffer(offset1) = currIndex1;
+                work_val_buffer(offset1) = opData1;
+                work_idx_buffer(offset1) = currIndex1;
             }

             __syncthreads();
@@ -160,9 +188,9 @@ struct PartitionedBlockwiseReductionWithIndex
         index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

-        accuData  = type_convert<AccDataType>(block_val_buffer[offset]);
-        accuIndex = block_idx_buffer[offset];
-    }
+        in_out_value = work_val_buffer[offset];
+        in_out_index = work_idx_buffer[offset];
+    };
 };
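The value+index variant follows the same tree pattern, but its step offset grows as `1 << I`, pairing ever-wider blocks while carrying along the index that wins under the reduce operation. A plain-C++ sketch of one such combine order, using max as the operation (CK's `Accumulation::Calculate` generalizes the comparison):

```cpp
#include <utility>
#include <vector>

std::pair<float, int> block_reduce_argmax(std::vector<float>& val, std::vector<int>& idx)
{
    for(int stride = 1; stride < int(val.size()); stride *= 2) // indOffset = 1 << I
        for(int k = 0; k + stride < int(val.size()); k += 2 * stride)
            if(val[k + stride] > val[k]) // keep the winning value and its index
            {
                val[k] = val[k + stride];
                idx[k] = idx[k + stride];
            }
    return {val[0], idx[0]};
}
```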
 }; // end of namespace ck
...
@@ -5,7 +5,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-enum struct ConvolutionBackwardDataSpecialization_t
+enum struct ConvolutionBackwardDataSpecialization
 {
     Default,
     Filter1x1Stride1Pad0,
...
@@ -7,7 +7,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-enum struct ConvolutionForwardSpecialization_t
+enum struct ConvolutionForwardSpecialization
 {
     Default,
     Filter1x1Pad0,
@@ -15,14 +15,14 @@ enum struct ConvolutionForwardSpecialization_t
     OddC,
 };

-inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization_t& s)
+inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s)
 {
     switch(s)
     {
-    case ConvolutionForwardSpecialization_t::Default: return "Default";
-    case ConvolutionForwardSpecialization_t::Filter1x1Pad0: return "Filter1x1Pad0";
-    case ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
-    case ConvolutionForwardSpecialization_t::OddC: return "OddC";
+    case ConvolutionForwardSpecialization::Default: return "Default";
+    case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
+    case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
+    case ConvolutionForwardSpecialization::OddC: return "OddC";
     default: return "Unrecognized specialization!";
     }
 }
...