Commit d27e0691 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'upstream/develop' into merge_upstream_1129

also fix regression
parents 0a7174ad a2969aa8
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
// Tensor layout tags passed to the device operation: NDHWGC input, GKZYXC weight,
// NDHWGK output.
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
// Element-wise operators: PassThrough is applied to the input and weight, and
// ScaleAddScaleAddRelu is the output (epilogue) operator.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
// Fixed problem size used by this example (3 spatial dimensions: D, H, W).
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 32; // group count
static constexpr ck::index_t N = 64; // batch size
static constexpr ck::index_t K = 64; // output channel
static constexpr ck::index_t C = 32; // input channel (per group)
static constexpr ck::index_t Z = 3; // filter D
static constexpr ck::index_t Y = 3; // filter H
static constexpr ck::index_t X = 3; // filter W
static constexpr ck::index_t Di = 14; // input D
static constexpr ck::index_t Hi = 14; // input H
static constexpr ck::index_t Wi = 14; // input W
// Output extents are declared equal to the input extents.
static constexpr ck::index_t Do = 14; // output D
static constexpr ck::index_t Ho = 14; // output H
static constexpr ck::index_t Wo = 14; // output W
/// Minimal owning wrapper around one hipMalloc'd allocation.
///
/// The buffer is requested in the constructor and handed to hipFree in the
/// destructor. HIP status codes are intentionally discarded (example code), so
/// callers are not told whether the allocation succeeded.
struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    /// Requests mem_size bytes from hipMalloc and stores the resulting pointer.
    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    // The destructor frees p_mem_ unconditionally, so a copied object would hand
    // the same pointer to hipFree twice. Forbid copying (this also suppresses the
    // implicit move operations).
    SimpleDeviceMem(const SimpleDeviceMem&) = delete;
    SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;

    /// Returns the stored pointer; ownership stays with this object.
    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};
int execute_conv_fwd_scaleadd_scaleadd_relu()
{
// We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space.
// However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW.
// Hence, we need to adjust the order of strides.
std::array<ck::index_t, 6> in_lengths{G, N, C, Di, Hi, Wi};
std::array<ck::index_t, 6> in_strides{
C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
std::array<ck::index_t, 6> wei_lengths{G, K, C, Z, Y, X};
std::array<ck::index_t, 6> wei_strides{
K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo};
std::array<ck::index_t, 6> out_strides{
K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1, 1};
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1, 1};
std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C);
SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K);
SimpleDeviceMem d0(sizeof(std::tuple_element_t<0, DDataTypes>) * N * Do * Ho * Wo * G * K);
SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * N * Do * Ho * Wo * G * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<OutLayout, OutLayout>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<std::tuple_element_t<0, DDataTypes>, std::tuple_element_t<1, DDataTypes>>,
OutDataType,
PassThrough,
PassThrough,
ScaleAddScaleAddRelu>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{d0.GetDeviceBuffer(), d1.GetDeviceBuffer()},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
{out_lengths, out_lengths},
{out_strides, out_strides},
out_lengths,
out_strides,
filter_strides,
filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
ScaleAddScaleAddRelu{2.f, 2.f});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop =
std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 2 * N * Ho * Wo * G * K;
std::size_t num_bytes =
sizeof(InDataType) * N * Hi * Wi * G * C + sizeof(WeiDataType) * G * K * Y * X * C +
(sizeof(OutDataType) + sizeof(std::tuple_element_t<0, DDataTypes>) +
sizeof(std::tuple_element_t<1, DDataTypes>)) *
N * Ho * Wo * G * K;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cerr << op_name << " does not support this problem" << std::endl;
}
}
if(best_op_id < 0)
{
std::cerr << "no suitable instance" << std::endl;
return EXIT_FAILURE;
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr =
op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{d0.GetDeviceBuffer(), d1.GetDeviceBuffer()},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
{out_lengths, out_lengths},
{out_strides, out_strides},
out_lengths,
out_strides,
filter_strides,
filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
ScaleAddScaleAddRelu{2.f, 2.f});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
// bf16 variant: element types consumed by the .inc file included below
// (input, weight, output, and the two D tensors).
using InDataType = ck::bhalf_t;
using WeiDataType = ck::bhalf_t;
using OutDataType = ck::bhalf_t;
// Use std tuple instead of ck tuple to avoid clang
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<ck::bhalf_t, ck::bhalf_t>;
#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
// Entry point: the process exit status is whatever the example driver returns.
int main()
{
    const int status = execute_conv_fwd_scaleadd_scaleadd_relu();
    return status;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
// fp16 variant: element types consumed by the .inc file included below
// (input, weight, output, and the two D tensors).
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
// Use std tuple instead of ck tuple to avoid clang
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<ck::half_t, ck::half_t>;
#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
// Entry point: the process exit status is whatever the example driver returns.
int main()
{
    const int status = execute_conv_fwd_scaleadd_scaleadd_relu();
    return status;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
// fp32 variant: element types consumed by the .inc file included below
// (input, weight, output, and the two D tensors).
using InDataType = float;
using WeiDataType = float;
using OutDataType = float;
// Use std tuple instead of ck tuple to avoid clang
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<float, float>;
#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
// Entry point: the process exit status is whatever the example driver returns.
int main()
{
    const int status = execute_conv_fwd_scaleadd_scaleadd_relu();
    return status;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
// int8 variant: element types consumed by the .inc file included below.
// Input, weight and output are int8_t, while both D tensors are float.
using InDataType = int8_t;
using WeiDataType = int8_t;
using OutDataType = int8_t;
// Use std tuple instead of ck tuple to avoid clang
// implicit instantiation of undefined template error.
using DDataTypes = std::tuple<float, float>;
#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc"
// Entry point: the process exit status is whatever the example driver returns.
int main()
{
    const int status = execute_conv_fwd_scaleadd_scaleadd_relu();
    return status;
}
# One client executable per data type (fp32, fp16, bf16, int8) for the grouped
# convolution forward "scaleadd_ab" example. Each target is built from a single
# source file and links privately against composable_kernel::device_conv_operations.
add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp32 grouped_conv_fwd_scaleadd_ab_fp32.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp16 grouped_conv_fwd_scaleadd_ab_fp16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_ab_bf16 grouped_conv_fwd_scaleadd_ab_bf16.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_convnd_fwd_scaleadd_ab_int8 grouped_conv_fwd_scaleadd_ab_int8.cpp)
target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_conv_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
// Tensor layout tags passed to the device operation: NDHWGC input, GKZYXC weight,
// NDHWGK output.
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
// Element-wise operators: ScaleAdd is applied on the input (A) and weight (B) side,
// PassThrough on the output.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
// Fixed problem size used by this example (3 spatial dimensions: D, H, W).
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 32; // group count
static constexpr ck::index_t N = 64; // batch size
static constexpr ck::index_t K = 64; // output channel
static constexpr ck::index_t C = 32; // input channel (per group)
static constexpr ck::index_t Z = 3; // filter D
static constexpr ck::index_t Y = 3; // filter H
static constexpr ck::index_t X = 3; // filter W
static constexpr ck::index_t Di = 14; // input D
static constexpr ck::index_t Hi = 14; // input H
static constexpr ck::index_t Wi = 14; // input W
// Output extents are declared equal to the input extents.
static constexpr ck::index_t Do = 14; // output D
static constexpr ck::index_t Ho = 14; // output H
static constexpr ck::index_t Wo = 14; // output W
/// Minimal owning wrapper around one hipMalloc'd allocation.
///
/// The buffer is requested in the constructor and handed to hipFree in the
/// destructor. HIP status codes are intentionally discarded (example code), so
/// callers are not told whether the allocation succeeded.
struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    /// Requests mem_size bytes from hipMalloc and stores the resulting pointer.
    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    // The destructor frees p_mem_ unconditionally, so a copied object would hand
    // the same pointer to hipFree twice. Forbid copying (this also suppresses the
    // implicit move operations).
    SimpleDeviceMem(const SimpleDeviceMem&) = delete;
    SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;

    /// Returns the stored pointer; ownership stays with this object.
    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};
/// Profiles every registered grouped 3D forward-convolution instance that applies
/// ScaleAdd on the input (A) and weight (B) side, prints per-instance timings, reports
/// the fastest supported instance and finally runs that instance once more without
/// timing.
///
/// InDataType and WeiDataType are expected to be two-element ck::Tuple types and
/// OutDataType a scalar type, all declared by the translation unit that includes this
/// file.
///
/// @return 0 on success, EXIT_FAILURE when no instance supports the problem.
int execute_conv_fwd_scaleadd_ab()
{
    constexpr ck::index_t NumAs = 2;
    constexpr ck::index_t NumBs = 2;
    constexpr float scale       = 1.5f;

    // We have NDHWGC/GKZYXC/NDHWGK (x, weight, y) in memory space.
    // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW.
    // Hence, we need to adjust the order of strides.
    std::array<ck::index_t, 6> in_lengths{G, N, C, Di, Hi, Wi};
    std::array<ck::index_t, 6> in_strides{
        C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
    std::array<ck::index_t, 6> wei_lengths{G, K, C, Z, Y, X};
    std::array<ck::index_t, 6> wei_strides{
        K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
    std::array<ck::index_t, 6> out_lengths{G, N, K, Do, Ho, Wo};
    std::array<ck::index_t, 6> out_strides{
        K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};

    std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1, 1};
    std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1, 1};
    std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
    std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};

    // Element 0 of each tuple is the tensor itself, element 1 its bias tensor.
    using InputDtype      = ck::tuple_element_t<0, InDataType>;
    using InputBiasDtype  = ck::tuple_element_t<1, InDataType>;
    using WeightDtype     = ck::tuple_element_t<0, WeiDataType>;
    using WeightBiasDtype = ck::tuple_element_t<1, WeiDataType>;

    // sizeof(...) comes first in each product so the arithmetic is done in std::size_t.
    SimpleDeviceMem in(sizeof(InputDtype) * N * Di * Hi * Wi * G * C);
    SimpleDeviceMem in_bias(sizeof(InputBiasDtype) * N * Di * Hi * Wi * G * C);
    SimpleDeviceMem wei(sizeof(WeightDtype) * G * K * Z * Y * X * C);
    SimpleDeviceMem wei_bias(sizeof(WeightBiasDtype) * G * K * Z * Y * X * C);
    SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K);

    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                                                                   InLayout,
                                                                                   WeiLayout,
                                                                                   ck::Tuple<>,
                                                                                   OutLayout,
                                                                                   InDataType,
                                                                                   WeiDataType,
                                                                                   ck::Tuple<>,
                                                                                   OutDataType,
                                                                                   ScaleAdd,
                                                                                   ScaleAdd,
                                                                                   PassThrough>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::array<const void*, NumAs> as = {in.GetDeviceBuffer(), in_bias.GetDeviceBuffer()};
    std::array<const void*, NumBs> bs = {wei.GetDeviceBuffer(), wei_bias.GetDeviceBuffer()};
    std::array<const void*, 0> ds{};

    // The timed sweep and the final untimed run use exactly the same argument list,
    // so it is built in one place. There are no D tensors, hence the empty braces.
    const auto make_argument = [&](const auto& op) {
        return op->MakeArgumentPointer(as,
                                       bs,
                                       ds,
                                       out.GetDeviceBuffer(),
                                       in_lengths,
                                       in_strides,
                                       wei_lengths,
                                       wei_strides,
                                       {},
                                       {},
                                       out_lengths,
                                       out_strides,
                                       filter_strides,
                                       filter_dilations,
                                       input_left_pads,
                                       input_right_pads,
                                       ScaleAdd{scale},
                                       ScaleAdd{scale},
                                       PassThrough{});
    };

    // Work estimates; they do not depend on the instance, so they are computed once.
    // Every term starts from std::size_t so no product is evaluated in 32-bit index_t.
    const std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Z * Y * X +
                             std::size_t(N) * Di * Hi * Wi * G * C +
                             std::size_t(G) * K * Z * Y * X * C;
    // Bytes moved: input + input bias, weight + weight bias, and the output.
    // InDataType/WeiDataType are tuple types, so the per-element traffic is the sum of
    // the element sizes rather than a multiple of the tuple's own size.
    const std::size_t num_bytes =
        (sizeof(InputDtype) + sizeof(InputBiasDtype)) * N * Di * Hi * Wi * G * C +
        (sizeof(WeightDtype) + sizeof(WeightBiasDtype)) * G * K * Z * Y * X * C +
        sizeof(OutDataType) * N * Do * Ho * Wo * G * K;

    std::string best_op_name;
    int best_op_id        = -1;
    float best_avg_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;
    float best_tflops     = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    // best_op_id is an int, so iterate with an int bound instead of comparing
    // a signed index against size().
    const int num_ops = static_cast<int>(op_ptrs.size());

    for(int i = 0; i < num_ops; ++i)
    {
        auto& op_ptr        = op_ptrs[i];
        auto argument_ptr   = make_argument(op_ptr);
        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            // avg_time is used as milliseconds in the unit conversions below.
            float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
            float gb_per_sec = num_bytes / 1.E6 / avg_time;

            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_id      = i;
                best_op_name    = op_name;
                best_avg_time   = avg_time;
                best_gb_per_sec = gb_per_sec;
                best_tflops     = tflops;
            }
        }
        else
        {
            std::cerr << op_name << " does not support this problem" << std::endl;
        }
    }

    if(best_op_id < 0)
    {
        std::cerr << "no suitable instance" << std::endl;
        return EXIT_FAILURE;
    }

    std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = make_argument(op_ptr);
        auto invoker_ptr  = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
// bf16 variant. InDataType and WeiDataType are two-element tuples; the included
// .inc reads element 0 as the tensor type and element 1 as the bias type.
using InDataType = ck::Tuple<ck::bhalf_t, ck::bhalf_t>;
using WeiDataType = ck::Tuple<ck::bhalf_t, ck::bhalf_t>;
using OutDataType = ck::bhalf_t;
#include "grouped_conv_fwd_scaleadd_ab.inc"
// Entry point: the process exit status is whatever the example driver returns.
int main()
{
    const int status = execute_conv_fwd_scaleadd_ab();
    return status;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
// fp16 variant. InDataType and WeiDataType are two-element tuples; the included
// .inc reads element 0 as the tensor type and element 1 as the bias type.
using InDataType = ck::Tuple<ck::half_t, ck::half_t>;
using WeiDataType = ck::Tuple<ck::half_t, ck::half_t>;
using OutDataType = ck::half_t;
#include "grouped_conv_fwd_scaleadd_ab.inc"
// Entry point: the process exit status is whatever the example driver returns.
int main()
{
    const int status = execute_conv_fwd_scaleadd_ab();
    return status;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
// fp32 variant. InDataType and WeiDataType are two-element tuples; the included
// .inc reads element 0 as the tensor type and element 1 as the bias type.
using InDataType = ck::Tuple<float, float>;
using WeiDataType = ck::Tuple<float, float>;
using OutDataType = float;
#include "grouped_conv_fwd_scaleadd_ab.inc"
// Entry point: the process exit status is whatever the example driver returns.
int main()
{
    const int status = execute_conv_fwd_scaleadd_ab();
    return status;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
// int8 variant. InDataType and WeiDataType are two-element tuples; the included
// .inc reads element 0 as the tensor type and element 1 as the bias type.
using InDataType = ck::Tuple<int8_t, int8_t>;
using WeiDataType = ck::Tuple<int8_t, int8_t>;
using OutDataType = int8_t;
#include "grouped_conv_fwd_scaleadd_ab.inc"
// Entry point: the process exit status is whatever the example driver returns.
int main()
{
    const int status = execute_conv_fwd_scaleadd_ab();
    return status;
}
...@@ -48,7 +48,7 @@ else() ...@@ -48,7 +48,7 @@ else()
endif() endif()
endif() endif()
find_package(composable_kernel COMPONENTS device_operations) find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations)
find_package(hip REQUIRED PATHS /opt/rocm) find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}") message(STATUS "Build with HIP ${hip_VERSION}")
......
...@@ -309,6 +309,8 @@ XML_OUTPUT ...@@ -309,6 +309,8 @@ XML_OUTPUT
XML_PROGRAMLISTING XML_PROGRAMLISTING
) )
set(WARN_AS_ERROR YES)
set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file")
function(add_doxygen_doc) function(add_doxygen_doc)
......
...@@ -70,6 +70,7 @@ else() ...@@ -70,6 +70,7 @@ else()
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
-Wno-unused-template
) )
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
list(APPEND CMAKE_COMPILER_WARNINGS list(APPEND CMAKE_COMPILER_WARNINGS
......
...@@ -2,7 +2,101 @@ ...@@ -2,7 +2,101 @@
Contributor's Guide Contributor's Guide
=================== ===================
Pull-request guidelines This chapter explains how to get started contributing to the Composable Kernel project and what are
======================= the contributing rules.
[TODO] Getting started
===============
#. **Documentation:** Before contributing to the library, familiarize yourself with the
`Composable Kernel User Guide <https://rocm.docs.amd.com/projects/composable_kernel/en/latest/>`_.
It provides insight into the core concepts, environment configuration, and steps to obtain or
build the library. You can also find some of this information in the
`README file <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/README.md>`_
on the project's GitHub page.
#. **Additional reading:** We also recommend reading a `blog post
<https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_
from the AMD Community portal. It offers a deeper understanding of the library's objectives and
showcases its performance capabilities.
#. **General information:** For broader information about AMD products, consider exploring the
`AMD Developer Central portal <https://www.amd.com/en/developer.html>`_.
How do I contribute
===================
We deeply value contributions from our users. You can make an impact by reporting issues or
proposing code enhancements through pull requests.
Reporting issues
----------------
We use `GitHub issues <https://github.com/ROCmSoftwarePlatform/composable_kernel/issues>`_
to track public bugs and enhancement requests.
If you encounter an issue with the library, please check if the problem has already been
reported by searching existing issues on GitHub. If your issue seems unique, please submit a new
issue. All reported issues must include:
* A comprehensive description of the problem, including:
* What did you observe?
* Why do you think it is a bug (if it seems like one)?
* What did you expect to happen? What would indicate the resolution of the problem?
* Are there any known workarounds?
* Your configuration details, including:
* Which GPU are you using?
* Which OS version are you on?
* Which ROCm version are you using?
* Are you using a Docker image? If so, which one?
* Steps to reproduce the issue, including:
* What actions trigger the issue? What are the reproduction steps?
* If you build the library from scratch, what CMake command did you use?
* How frequently does this issue happen? Does it reproduce every time? Or is it a sporadic issue?
Before submitting any issue, ensure you have addressed all relevant questions from the checklist.
Creating Pull Requests
----------------------
You can submit `Pull Requests (PR) on GitHub
<https://github.com/ROCmSoftwarePlatform/composable_kernel/pulls>`_.
All contributors are required to develop their changes on a separate branch and then create a
pull request to merge their changes into the `develop` branch, which is the default
development branch in the Composable Kernel project. All external contributors must use their own
forks of the project to develop their changes.
When submitting a Pull Request you should:
* Describe the change providing information about the motivation for the change and a general
description of all code modifications.
* Verify and test the change:
* Run any relevant existing tests.
* Write new tests if added functionality is not covered by current tests.
* Ensure your changes align with the coding style defined in the ``.clang-format`` file located in
the project's root directory. We leverage `pre-commit` to run `clang-format` automatically. We
highly recommend contributors utilize this method to maintain consistent code formatting.
Instructions on setting up `pre-commit` can be found in the project's
`README file <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/README.md>`_
* Link your PR to any related issues:
* If there is an issue that is resolved by your change, please provide a link to the issue in
the description of your pull request.
* For larger contributions, structure your change into a sequence of smaller, focused commits, each
addressing a particular aspect or fix.
Following the above guidelines ensures a seamless review process and faster assistance from our
end.
Thank you for your commitment to enhancing the Composable Kernel project! We look forward to collaborating with you.
rocm-docs-core>=0.20.0 rocm-docs-core>=0.20.0
sphinxcontrib-bibtex==2.5.0 sphinxcontrib-bibtex==2.6.1
...@@ -42,12 +42,18 @@ fastjsonschema==2.18.0 ...@@ -42,12 +42,18 @@ fastjsonschema==2.18.0
# via rocm-docs-core # via rocm-docs-core
gitdb==4.0.10 gitdb==4.0.10
# via gitpython # via gitpython
gitpython==3.1.31 gitpython==3.1.35
# via rocm-docs-core # via rocm-docs-core
idna==3.4 idna==3.4
# via requests # via requests
imagesize==1.4.1 imagesize==1.4.1
# via sphinx # via sphinx
importlib-metadata==6.8.0
# via
# sphinx
# sphinxcontrib-bibtex
importlib-resources==6.1.0
# via rocm-docs-core
jinja2==3.1.2 jinja2==3.1.2
# via # via
# myst-parser # myst-parser
...@@ -90,9 +96,13 @@ pygments==2.14.0 ...@@ -90,9 +96,13 @@ pygments==2.14.0
# pydata-sphinx-theme # pydata-sphinx-theme
# sphinx # sphinx
pyjwt[crypto]==2.6.0 pyjwt[crypto]==2.6.0
# via pygithub # via
# pygithub
# pyjwt
pynacl==1.5.0 pynacl==1.5.0
# via pygithub # via pygithub
pytz==2023.3.post1
# via babel
pyyaml==6.0 pyyaml==6.0
# via # via
# myst-parser # myst-parser
...@@ -103,7 +113,7 @@ requests==2.28.2 ...@@ -103,7 +113,7 @@ requests==2.28.2
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core>=0.20.0 rocm-docs-core==0.27.0
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via # via
...@@ -139,7 +149,7 @@ sphinx-notfound-page==0.8.3 ...@@ -139,7 +149,7 @@ sphinx-notfound-page==0.8.3
# via rocm-docs-core # via rocm-docs-core
sphinxcontrib-applehelp==1.0.4 sphinxcontrib-applehelp==1.0.4
# via sphinx # via sphinx
sphinxcontrib-bibtex==2.5.0 sphinxcontrib-bibtex==2.6.1
# via -r requirements.in # via -r requirements.in
sphinxcontrib-devhelp==1.0.2 sphinxcontrib-devhelp==1.0.2
# via sphinx # via sphinx
...@@ -157,3 +167,7 @@ urllib3==1.26.15 ...@@ -157,3 +167,7 @@ urllib3==1.26.15
# via requests # via requests
wrapt==1.15.0 wrapt==1.15.0
# via deprecated # via deprecated
zipp==3.17.0
# via
# importlib-metadata
# importlib-resources
if(DL_KERNELS) add_custom_target(example_gemm_dl)
add_custom_target(example_gemm_dl)
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) add_example_dependencies(example_gemm_dl example_gemm_dl_fp32)
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_example_dependencies(example_gemm_dl example_gemm_dl_fp16)
add_dependencies(example_gemm_dl example_gemm_dl_fp16)
add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp) add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16)
endif() add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_dependencies(example_gemm_dl example_gemm_dl_int8)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) if(USE_BITINT_EXTENSION_INT4)
add_dependencies(example_gemm_dl example_gemm_dl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_int4) add_example_dependencies(example_gemm_dl example_gemm_dl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
endif()
add_custom_target(example_gemm_xdl) add_custom_target(example_gemm_xdl)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_custom_target(example_gemm_wmma) add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()
endif() endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8)
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_int4) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) # FIXME: re-enable this example as test when SWDEV-335738 is fixed
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
endif()
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_f8) add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
endif() add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
endif()
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
...@@ -3,31 +3,33 @@ ...@@ -3,31 +3,33 @@
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp"
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CDataType = ck::half_t;
using F16 = ck::half_t;
using ALayout = Col; using ALayout = Row;
using BLayout = Row; using BLayout = Col;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>;
// clang-format on // // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
...@@ -9,13 +9,13 @@ ...@@ -9,13 +9,13 @@
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::half_t; using BDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Row;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -30,7 +30,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl ...@@ -30,7 +30,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>;
// // clang-format on // // clang-format on
// clang-format off // clang-format off
...@@ -39,9 +39,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl ...@@ -39,9 +39,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 8, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>;
ck::make_default_loop_scheduler(),
ck::PipelineVersion::v2>;
// clang-format on // clang-format on
using DeviceGemmInstance = DeviceGemmInstance1; using DeviceGemmInstance = DeviceGemmInstance1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment