Commit 057140b1 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 134fc2e7 12a8883c
if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp)
target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_conv_operations)
endif() endif()
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_operations) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations)
endif() endif()
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations)
endif() endif()
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
\ No newline at end of file \ No newline at end of file
add_executable(client_groupnorm_swish groupnorm_swish.cpp) add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp)
target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_operations) target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations)
add_executable(client_groupnorm_swish_fwd groupnorm_swish_fwd.cpp)
target_link_libraries(client_groupnorm_swish_fwd PRIVATE composable_kernel::device_other_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp"
using DYDataType = float;
using XDataType = float;
using GammaDataType = float;
using MeanInvStdDataType = float;
using DXDataType = float;
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
ck::index_t N = 32;
ck::index_t H = 16;
ck::index_t W = 16;
ck::index_t G = 64;
ck::index_t C = 128;
std::size_t length = N * H * W * G * C;
std::vector<ck::index_t> strideDy = {H * W * G * C, W * G * C, G * C, C, 1};
std::vector<ck::index_t> strideX = strideDy;
std::vector<ck::index_t> strideDx = strideDy;
std::vector<ck::index_t> strideGamma = {0, 0, 0, C, 1};
std::vector<ck::index_t> strideMeanInvStd = {G, 0, 0, 1, 0};
SimpleDeviceMem dy_dev(sizeof(DYDataType) * length);
SimpleDeviceMem x_dev(sizeof(XDataType) * length);
SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * G * C);
SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G);
SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G);
SimpleDeviceMem dx_dev(sizeof(DXDataType) * length);
using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},
strideDy,
strideX,
strideGamma,
strideMeanInvStd,
strideMeanInvStd,
strideDx,
{1, 2, 4}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_byte = sizeof(DYDataType) * length + sizeof(XDataType) * length +
sizeof(GammaDataType) * G * C +
sizeof(MeanInvStdDataType) * N * G * 2 +
sizeof(DXDataType) * length;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
// run the best intance
if(found)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},
strideDy,
strideX,
strideGamma,
strideMeanInvStd,
strideMeanInvStd,
strideDx,
{1, 2, 4}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp)
target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_operations) target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_other_operations)
add_executable(client_max_pool2d_bwd max_pool2d_bwd.cpp) add_executable(client_max_pool2d_bwd max_pool2d_bwd.cpp)
target_link_libraries(client_max_pool2d_bwd PRIVATE composable_kernel::device_operations) target_link_libraries(client_max_pool2d_bwd PRIVATE composable_kernel::device_other_operations)
add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp)
target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations) target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_other_operations)
add_executable(client_avg_pool3d_bwd avg_pool3d_bwd.cpp) add_executable(client_avg_pool3d_bwd avg_pool3d_bwd.cpp)
target_link_libraries(client_avg_pool3d_bwd PRIVATE composable_kernel::device_operations) target_link_libraries(client_avg_pool3d_bwd PRIVATE composable_kernel::device_other_operations)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations)
endif() endif()
...@@ -191,6 +191,7 @@ int main(int argc, char* argv[]) ...@@ -191,6 +191,7 @@ int main(int argc, char* argv[])
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance // run the best intance
if(found)
{ {
auto& op_ptr = op_ptrs[best_op_id]; auto& op_ptr = op_ptrs[best_op_id];
......
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_image_to_column image_to_column.cpp) add_executable(client_image_to_column image_to_column.cpp)
target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_operations) target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_other_operations)
add_executable(client_column_to_image column_to_image.cpp) add_executable(client_column_to_image column_to_image.cpp)
target_link_libraries(client_column_to_image PRIVATE composable_kernel::device_operations) target_link_libraries(client_column_to_image PRIVATE composable_kernel::device_other_operations)
add_executable(client_elementwise_transpose3d elementwise_transpose_3d.cpp) add_executable(client_elementwise_transpose3d elementwise_transpose_3d.cpp)
target_link_libraries(client_elementwise_transpose3d PRIVATE composable_kernel::device_operations) target_link_libraries(client_elementwise_transpose3d PRIVATE composable_kernel::device_other_operations)
...@@ -117,6 +117,7 @@ int main() ...@@ -117,6 +117,7 @@ int main()
<< best_op_name << std::endl; << best_op_name << std::endl;
// run the best intance // run the best intance
if(found)
{ {
auto& op_ptr = op_ptrs[best_op_id]; auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
......
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp) add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp) add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp) add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_conv_operations)
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
using InLayout = ck::tensor_layout::convolution::NDHWGC; using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC; using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK; using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using BiasLayout = ck::tensor_layout::convolution::G_K;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
...@@ -64,6 +65,9 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() ...@@ -64,6 +65,9 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo}; std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo};
std::array<ck::index_t, 6> out_strides{ std::array<ck::index_t, 6> out_strides{
K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
// Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW)
std::array<ck::index_t, 6> bias_lengths{G, 1, K, 1, 1, 1};
std::array<ck::index_t, 6> bias_strides{K, 0, 1, 0, 0, 0};
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1, 1}; std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1, 1};
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1, 1}; std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1, 1};
...@@ -74,13 +78,13 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() ...@@ -74,13 +78,13 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C);
SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K);
SimpleDeviceMem d0(sizeof(std::tuple_element_t<0, DDataTypes>) * N * Do * Ho * Wo * G * K); SimpleDeviceMem d0(sizeof(std::tuple_element_t<0, DDataTypes>) * N * Do * Ho * Wo * G * K);
SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * N * Do * Ho * Wo * G * K); SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * G * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
ck::Tuple<OutLayout, OutLayout>, ck::Tuple<OutLayout, BiasLayout>,
OutLayout, OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
...@@ -117,8 +121,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() ...@@ -117,8 +121,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
in_strides, in_strides,
wei_lengths, wei_lengths,
wei_strides, wei_strides,
{out_lengths, out_lengths}, {out_lengths, bias_lengths},
{out_strides, out_strides}, {out_strides, bias_strides},
out_lengths, out_lengths,
out_strides, out_strides,
filter_strides, filter_strides,
...@@ -187,8 +191,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() ...@@ -187,8 +191,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu()
in_strides, in_strides,
wei_lengths, wei_lengths,
wei_strides, wei_strides,
{out_lengths, out_lengths}, {out_lengths, bias_lengths},
{out_strides, out_strides}, {out_strides, bias_strides},
out_lengths, out_lengths,
out_strides, out_strides,
filter_strides, filter_strides,
......
add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp32 grouped_conv_fwd_scaleadd_ab_fp32.cpp) add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp32 grouped_conv_fwd_scaleadd_ab_fp32.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp16 grouped_conv_fwd_scaleadd_ab_fp16.cpp) add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp16 grouped_conv_fwd_scaleadd_ab_fp16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_ab_bf16 grouped_conv_fwd_scaleadd_ab_bf16.cpp) add_executable(client_grouped_convnd_fwd_scaleadd_ab_bf16 grouped_conv_fwd_scaleadd_ab_bf16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_ab_int8 grouped_conv_fwd_scaleadd_ab_int8.cpp) add_executable(client_grouped_convnd_fwd_scaleadd_ab_int8 grouped_conv_fwd_scaleadd_ab_int8.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_tensor_transform tensor_transform.cpp)
target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations)
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
using DataType = int;
template <typename Desc>
void Print1d(const Desc& desc)
{
std::cout << "Print1d" << std::endl;
for(ck::index_t w = 0; w < desc.GetLength(I0); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " ";
}
std::cout << std::endl;
}
template <typename Desc>
void Print2d(const Desc& desc)
{
std::cout << "Print2d" << std::endl;
for(ck::index_t h = 0; h < desc.GetLength(I0); h++)
{
for(ck::index_t w = 0; w < desc.GetLength(I1); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " ";
}
std::cout << std::endl;
}
}
template <typename Desc>
void Print3dCustom(const Desc& desc)
{
std::cout << "Print3dCustom" << std::endl;
for(ck::index_t d = 0; d < desc.GetLength(I0); d++)
{
for(ck::index_t h = 0; h < desc.GetLength(I1); h++)
{
for(ck::index_t w = 0; w < desc.GetLength(I2); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}
}
int main()
{
// Tensor descriptor traverse in row-major (need to reverse dims)
std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl;
// Basic descriptor 0, 1, 2, ... 30, 31
// (dims:4,8 strides:1,4)
const auto desc_4x8_s1x4 =
ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}));
std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(desc_4x8_s1x4);
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{});
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8)
const auto desc_4x2x4_s2x1x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
// Transform to 2d (column-major, need to to reverse dims)
const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
desc_4x2x4_s2x1x8,
ck::make_tuple(ck::make_pass_through_transform(4),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
Print2d(desc_4x2x4_s2x1x8_merged);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(2,4) strides:((1,4),(2,8)
const auto desc_2x2x2x4_s1x4x2x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
// Transform to 2d
const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
// Transform to 3d
const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8,
ck::make_tuple(ck::make_pass_through_transform(2),
ck::make_pass_through_transform(2),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d);
Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d
const auto desc_2x2x2x4_s1x4x2x8_nested =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_pass_through_transform(2),
ck::make_pass_through_transform(4)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))),
ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
ck::make_tuple(ck::Sequence<0>{}));
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested_merged_3d,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)),
ck::make_pass_through_transform(4)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d);
Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d);
Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d);
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/wrapper/layout.hpp"
using DataType = int;
template <typename Layout>
void Print1d(const Layout& layout)
{
std::cout << "Print1d" << std::endl;
for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++)
{
std::cout << layout(ck::make_tuple(w)) << " ";
}
std::cout << std::endl;
}
template <typename Layout>
void Print2d(const Layout& layout)
{
std::cout << "Print2d" << std::endl;
for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++)
{
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
{
std::cout << layout(ck::make_tuple(h, w)) << " ";
}
std::cout << std::endl;
}
}
// Print in (x,y),z pattern
template <typename Layout>
void Print3dCustom(const Layout& layout)
{
std::cout << "Print3dCustom" << std::endl;
for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++)
{
for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++)
{
for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++)
{
std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}
}
int main()
{
// Layout traverse in row-major
std::cout << "Note: Layout traverse in column-major" << std::endl;
// Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
// (dims:4,8 strides:1,4)
const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8);
std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(layout_4x8_s1x4);
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()<Cord1x1Type>();
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
// dims:4,(2,4) strides:2,(1,8)
const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4));
const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
Print2d(layout_4x2x4_s2x1x8);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(2,4) strides:((1,4),(2,8)
const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}),
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}));
const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
static const auto layout_2x2x2x4_s1x4x2x8 =
ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
Print2d(layout_2x2x2x4_s1x4x2x8);
Print3dCustom(layout_2x2x2x4_s1x4x2x8);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),2),4 strides:((1,4),2),8
// Transform to 2d
const auto shape_2x2x2x4_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}),
ck::Number<4>{});
const auto strides_s1x4x2x8_nested = ck::make_tuple(
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}),
ck::Number<8>{});
static const auto layout_2x2x2x4_s1x4x2x8_nested =
ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested);
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
Print1d(layout_2x2x2x4_s1x4x2x8_nested);
Print2d(layout_2x2x2x4_s1x4x2x8_nested);
Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested);
return 0;
}
...@@ -48,7 +48,7 @@ else() ...@@ -48,7 +48,7 @@ else()
endif() endif()
endif() endif()
find_package(composable_kernel COMPONENTS device_operations) find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations)
find_package(hip REQUIRED PATHS /opt/rocm) find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}") message(STATUS "Build with HIP ${hip_VERSION}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment