"docs/vscode:/vscode.git/clone" did not exist on "275ddade66d1b4da9733a339b094f916a66f6da9"
Commit d22dbec2 authored by zhoux's avatar zhoux
Browse files

Initial commit: release hytlass-0.1.0

# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
set(HYTLASS_EXAMPLES_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common)
add_custom_target(hytlass_examples)
add_custom_target(test_examples)
function(hytlass_example_add_executable NAME)
set(options)
set(oneValueArgs DISABLE_TESTS)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if (NOT DEFINED __DISABLE_TESTS)
set(__DISABLE_TESTS OFF)
endif()
hytlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS} BATCH_SOURCES OFF)
add_dependencies(hytlass_examples ${NAME})
target_link_libraries(
${NAME}
PRIVATE
HYTLASS
hytlass_tools_util_includes
$<$<BOOL:${HYTLASS_ENABLE_HIPBLAS}>:hip::hipblas>
hip
)
target_include_directories(
${NAME}
PRIVATE
${HYTLASS_EXAMPLES_COMMON_SOURCE_DIR}
${HYTLASS_EXAMPLES_UTILS_DIR}
)
install(
TARGETS ${NAME}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
)
hytlass_add_executable_tests(
test_examples_${NAME} ${NAME}
DEPENDS ${__DEPENDS}
DEPENDEES test_examples ${__DEPENDEES}
TEST_COMMAND_OPTIONS ${__TEST_COMMAND_OPTIONS}
DISABLE_EXECUTABLE_INSTALL_RULE
DISABLE_TESTS ${__DISABLE_TESTS}
)
endfunction()
foreach(EXAMPLE
00_hytlass_basic_gemm
01_hytlass_serial_splitk_gemm
02_hytlass_parallel_splitk_gemm
03_hytlass_streamk_gemm
04_hytlass_batch_gemm
05_hytlass_group_gemm
06_hute_streamk
07_hute_batch_gemm
08_hytlass_fused_gemm
09_hytlass_tensorop_conv2d
10_hytlass_tensorop_wgrad_split_k
11_hytlass_tensorop_group_conv
12_depthwise_simt_conv2dfprop
13_hytlass_tensorop_fused_conv2d_fprop
14_gather_scatter_fusion
15_hute_group_gemm
16_hytlass_gemm_softmax
17_ell_block_sparse_gemm
18_gemm_with_abs_max
hute
)
add_subdirectory(${EXAMPLE})
endforeach()
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "hute/layout.hpp"
#include "hute/tensor.hpp"
#include "hute/util/print.hpp"
namespace example {
using namespace hute;
// Empty type used to disable gather/scatter for a GEMM argument
struct NoGather
{
template<class... Ts>
NoGather(Ts...) {};
};
/// Function object that applies an index to its argument
template <class Index>
struct IndexedGather
{
HUTE_HOST_DEVICE constexpr
IndexedGather(Index const *indices = {}): indices_(indices) {}
template <typename I>
HUTE_HOST_DEVICE constexpr
Index
operator()(I i) const { return indices_[i]; }
HUTE_HOST_DEVICE friend
void
print(IndexedGather const &s) {
hute::print("Indexed");
}
Index const *indices_;
};
/// Function object that applies a stride to its argument
/// Example: StridedGather<_2> gathers every other row/column
template <class Stride>
struct StridedGather
{
HUTE_HOST_DEVICE constexpr
StridedGather(Stride stride = {}): stride_(stride) {}
template <class I>
HUTE_HOST_DEVICE constexpr
auto
operator()(I i) const { return i * stride_; }
HUTE_HOST_DEVICE friend
void
print(StridedGather const &s) {
hute::print("Strided{");
print(s.stride_);
hute::print("}");
}
Stride stride_;
};
/// Custom stride object that applies a function followed by a stride
template <class Func, class Stride>
struct CustomStride
{
HUTE_HOST_DEVICE constexpr
CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {}
template <class I>
HUTE_HOST_DEVICE constexpr friend
auto
operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; }
template <class I>
HUTE_HOST_DEVICE constexpr friend
auto
operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; }
HUTE_HOST_DEVICE friend
void
print(CustomStride const & s) {
hute::print("Custom{");
print(s.func_);
hute::print(",");
print(s.stride_);
hute::print("}");
}
template<class Div>
HUTE_HOST_DEVICE constexpr friend
auto
safe_div(CustomStride const &s, Div const &div)
{
return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
}
// Circumvent the requirement on make_layout that shape and stride are integral
template <class Shape>
HUTE_HOST_DEVICE constexpr friend
auto
make_layout(Shape const &shape, CustomStride const &stride)
{
return Layout<Shape, CustomStride>(shape, stride);
}
Func func_;
Stride stride_;
};
template<class Stride, class Func>
HYTLASS_HOST_DEVICE
auto
make_custom_stride_layout(Stride const &stride, Func&& func)
{
// Use a dummy shape and replace the first non-unit stride with a custom gather stride
auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
return make_layout(repeat_like(stride, _1{}),
replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
}
/// Helper function to optionally create a gather tensor
template<class Iterator, class Shape, class Stride, class Func>
HYTLASS_HOST_DEVICE
auto
make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func)
{
if constexpr (not hytlass::platform::is_same<remove_cvref_t<Func>, NoGather>::value) {
Layout matrix_layout = make_identity_layout(shape);
auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
} else {
return make_tensor(iter, shape, stride);
}
}
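// A minimal usage sketch (the pointer, extents, stride, and index array below are
// hypothetical, not part of this file): gather rows of A through a device-side index
// array, or pass NoGather{} to fall back to an ordinary tensor.
//
//   int const* row_indices = /* device array of gathered row ids */ nullptr;
//   auto mA_gathered = make_gather_tensor(make_gmem_ptr(ptr_A),
//                                         make_shape(M, K, L), stride_A,
//                                         IndexedGather<int>{row_indices});
//   auto mA_plain    = make_gather_tensor(make_gmem_ptr(ptr_A),
//                                         make_shape(M, K, L), stride_A,
//                                         NoGather{});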
} // namespace example
namespace hute
{
template<int N, int I, class Shape, class Stride>
HUTE_HOST_DEVICE constexpr
auto
upcast(Shape const& shape, Stride const& stride)
{
if constexpr (is_tuple<Shape>::value) {
return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N,I>(s,d); });
} else if constexpr (is_scaled_basis<Stride>::value) {
if constexpr (Stride::mode() == I) {
return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
} else {
return make_layout(shape, stride);
}
} else {
return upcast<N>(shape, stride);
}
HUTE_GCC_UNREACHABLE;
}
template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
HUTE_HOST_DEVICE constexpr
auto
upcast(ComposedLayout<Layout<OuterShape,OuterStride>,Offset,Layout<Shape,Stride>> const& layout)
{
// Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; });
constexpr int I = decltype(idx)::value;
// Upcast the outer layout (works as expected)
auto outer = upcast<N>(layout.layout_a());
// Upcast the accumulated offset along stride-1 mode
auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));
// Upcast the inner layout's shape along stride-1 mode
auto inner = upcast<N,I>(layout.layout_b().shape(), layout.layout_b().stride());
return composition(outer, offset, inner);
}
} // namespace hute
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#pragma once
#include <algorithm>
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass_unit_test.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/distribution.h"
#include "hytlass/util/packed_stride.hpp"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_norm.h"
#include "hytlass/util/reference/host/gett.hpp"
#include "hytlass/util/GPU_Clock.hpp"
#include "testbed_utils.h"
#include "hytlass/kernel_hardware_info.hpp"
#include "hytlass/layout/matrix.h"
#include "hytlass/matrix_coord.h"
#include "hytlass/gemm/gemm.h"
#include "hytlass/epilogue/fusion/operations.hpp"
#include "hute/int_tuple.hpp"
namespace test {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail{
// Helper classes that fall back to a default data type when
// the Gemm::EpilogueOutputOp does not define ElementCompute
// or ElementScalar.
template <typename Gemm, typename Default, typename = void>
struct ElementComputeType {
using Type = Default;
};
template <typename Gemm, typename Default>
struct ElementComputeType<Gemm, Default, std::void_t<typename Gemm::EpilogueOutputOp::ElementCompute>> {
using Type = typename Gemm::EpilogueOutputOp::ElementCompute;
};
template <typename Gemm, typename Default, typename = void>
struct ElementScalarType {
using Type = Default;
};
template <typename Gemm, typename Default>
struct ElementScalarType<Gemm, Default, std::void_t<typename Gemm::EpilogueOutputOp::ElementScalar>> {
using Type = typename Gemm::EpilogueOutputOp::ElementScalar;
};
// The number of splits to test.
//
// This class makes it harder to confuse the order of arguments
// of the various run(...) functions in this file. The constructor
// is explicit, so one can't just type 42 (or false, which the
// compiler unhelpfully turns into 0); one has to type Splits(42).
// Splits() picks the default number of splits, 1.
//
// The conversion-to-int operator (operator int()) MUST be explicit!
// Conversion to int MUST require static_cast<int>.
// Otherwise, that defeats a key purpose of this class,
// which is to catch common errors of confusing the order
// of function arguments.
class Splits {
public:
Splits() = default;
template<class IntegralNotBool,
__HUTE_REQUIRES((std::is_integral_v<IntegralNotBool> &&
!std::is_same_v<IntegralNotBool, bool>)) >
explicit Splits(IntegralNotBool splits) : splits_(splits) {}
explicit operator int() const { return splits_; }
private:
int splits_ = 1;
};
// The number of iterations to test.
//
// This class, like Splits above, makes it harder to confuse
// the order of arguments of the various run(...) functions in this file.
// Iterations() picks the default number of iterations, 20.
class Iterations {
public:
Iterations() = default;
template<class IntegralNotBool,
__HUTE_REQUIRES((std::is_integral_v<IntegralNotBool> &&
!std::is_same_v<IntegralNotBool, bool>)) >
explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {}
explicit operator int() const { return iterations_; }
private:
int iterations_ = 20;
};
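// A short usage sketch (values chosen for illustration): the explicit constructors and
// explicit operator int() force call sites to spell out what each argument means when
// calling the run(...) functions later in this file.
//
//   run(problem_size, alpha, beta,
//       /*profiling=*/false,
//       Iterations(50),   // intent is unambiguous
//       Splits(4));
//   // run(problem_size, alpha, beta, false, 50, 4);  // does not compile:
//   // the int -> Iterations/Splits constructors are explicit.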
template <
typename Gemm,
template <class T> class ActivationFunctor_ = hytlass::epilogue::thread::Identity
>
struct TestbedImpl {
// Kernel data types
using ElementA = typename Gemm::GemmKernel::ElementA;
using StrideA = typename Gemm::GemmKernel::StrideA;
using ElementB = typename Gemm::GemmKernel::ElementB;
using StrideB = typename Gemm::GemmKernel::StrideB;
using ElementC = std::conditional_t<std::is_void_v<typename Gemm::GemmKernel::ElementC>,
typename Gemm::GemmKernel::ElementD,typename Gemm::GemmKernel::ElementC>;
using StrideC = typename Gemm::GemmKernel::StrideC;
using ElementD = typename Gemm::GemmKernel::ElementD;
using StrideD = typename Gemm::GemmKernel::StrideD;
using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
/// For custom EVTs
using ElementCompute = typename ElementComputeType<Gemm, ElementAccumulator>::Type;
using ElementScalar = typename ElementScalarType<Gemm, ElementCompute>::Type;
using ActivationFunctor = ActivationFunctor_<ElementCompute>;
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static constexpr uint32_t mma_promotion_interval = 4;
// Looks at Hute Stride to check Row / Column Major
template<typename Stride>
static constexpr bool is_row_or_col_major(){
int stride_0 = int(hute::size<0>(Stride{}));
int stride_1 = int(hute::size<1>(Stride{}));
int depth = hute::depth(Stride{});
return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1);
}
// Note: this limitation comes from the testbed, not the library
static_assert(is_row_or_col_major<StrideA>(),
"ERROR : A Layout is neither Row nor Column Major");
static_assert(is_row_or_col_major<StrideB>(),
"ERROR : B Layout is neither Row nor Column Major");
static_assert(is_row_or_col_major<StrideC>(),
"ERROR : C Layout is neither Row nor Column Major");
static_assert(is_row_or_col_major<StrideD>(),
"ERROR : D Layout is neither Row nor Column Major");
// Deduce Hytlass Layouts (RowMajor & ColumnMajor)
using LayoutTagA = hytlass::detail::StrideToLayoutTagA_t<StrideA>;
using LayoutTagB = hytlass::detail::StrideToLayoutTagB_t<StrideB>;
using LayoutTagC = hytlass::detail::StrideToLayoutTagA_t<StrideC>;
using LayoutTagD = hytlass::detail::StrideToLayoutTagA_t<StrideD>;
/// Initialization
StrideA stride_a;
StrideB stride_b;
StrideC stride_c;
StrideD stride_d;
typename LayoutTagA::Stride stride_factor_A;
typename LayoutTagB::Stride stride_factor_B;
typename LayoutTagC::Stride stride_factor_C;
typename LayoutTagD::Stride stride_factor_D;
hytlass::Distribution::Kind init_A;
hytlass::Distribution::Kind init_B;
hytlass::Distribution::Kind init_C;
uint64_t seed;
static constexpr uint64_t kDefaultSeed = 4096;
hytlass::HostTensor<ElementA, LayoutTagA> tensor_A;
hytlass::HostTensor<ElementB, LayoutTagB> tensor_B;
hytlass::HostTensor<ElementC, LayoutTagC> tensor_C;
hytlass::HostTensor<ElementD, LayoutTagD> tensor_D;
hytlass::HostTensor<ElementD, LayoutTagD> reference_D;
uint32_t sm_count;
// Used to force multi-wave tests for persistent kernel schedules
constexpr static int MaxSmCount = 16;
//
// Methods
//
TestbedImpl(
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
uint64_t seed_ = kDefaultSeed
):
stride_factor_A(typename LayoutTagA::Stride()),
stride_factor_B(typename LayoutTagB::Stride()),
stride_factor_C(typename LayoutTagC::Stride()),
stride_factor_D(typename LayoutTagD::Stride()),
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
TestbedImpl(
typename LayoutTagA::Stride stride_factor_A_,
typename LayoutTagB::Stride stride_factor_B_,
typename LayoutTagC::Stride stride_factor_C_,
typename LayoutTagD::Stride stride_factor_D_,
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
uint64_t seed_ = kDefaultSeed
):
stride_factor_A(stride_factor_A_),
stride_factor_B(stride_factor_B_),
stride_factor_C(stride_factor_C_),
stride_factor_D(stride_factor_D_),
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
hytlass::TensorView<Element, Layout> view,
hytlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == hytlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = hytlass::sizeof_bits<Element>::value;
int bits_output = hytlass::sizeof_bits<ElementD>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
}
else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
}
else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
hytlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == hytlass::Distribution::Identity) {
hytlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == hytlass::Distribution::Gaussian) {
hytlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == hytlass::Distribution::Sequential) {
hytlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else if (dist_kind == hytlass::Distribution::AllOnes) {
hytlass::reference::host::TensorFill(view, Element(1));
}
else {
return false;
}
return true;
}
/// Initializes data structures; this is the batch specialization
void initialize(ProblemShapeType problem_size) {
//
// Allocate the GEMM workspace
//
auto problem_shape_MNKL = hute::append<4>(problem_size, 1);
auto M = hute::size<0>(problem_shape_MNKL);
auto N = hute::size<1>(problem_shape_MNKL);
auto K = hute::size<2>(problem_shape_MNKL);
auto L = hute::size<3>(problem_shape_MNKL);
stride_a = hytlass::make_hute_packed_stride(StrideA{}, hute::make_shape(M, K, L));
stride_b = hytlass::make_hute_packed_stride(StrideB{}, hute::make_shape(N, K, L));
stride_c = hytlass::make_hute_packed_stride(StrideC{}, hute::make_shape(M, N, L));
stride_d = hytlass::make_hute_packed_stride(StrideD{}, hute::make_shape(M, N, L));
// 2.x host tensor does not natively contain a batch stride or coord, so we spoof it by folding it into the outer mode
auto a_coord = hytlass::make_Coord(M * L, K);
auto c_coord = hytlass::make_Coord(M * L, N);
// Hytlass Row/Col major refers to the MxK times KxN matrix product,
// so HostTensor B is treated as KxN from the coord's point of view
auto b_coord = hytlass::make_Coord(K, N * L);
tensor_A.resize(a_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
tensor_B.resize(b_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
tensor_C.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
tensor_D.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D));
reference_D.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D), false);
(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022));
(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021));
(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020));
// It is possible to randomly initialize to all zeros, so override this with non-zeros
// in the upper left corner of each operand.
tensor_A.host_view().at({0, 0}) = ElementA(1);
tensor_B.host_view().at({0, 0}) = ElementB(1);
tensor_C.host_view().at({0, 0}) = ElementC(1);
hytlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
}
/// Initializes data structures; this is the split-K specialization
void initialize(ProblemShapeType problem_size, int slice_k) {
//
// Allocate the GEMM workspace
//
auto problem_shape_MNKL = hute::append<4>(problem_size, 1);
auto M = hute::size<0>(problem_shape_MNKL);
auto N = hute::size<1>(problem_shape_MNKL);
auto K = hute::size<2>(problem_shape_MNKL);
auto L = 1; // The L mode holds slice_k here; split-K and batching cannot coexist, so the batch count must be 1
stride_a = hytlass::make_hute_packed_stride(StrideA{}, hute::make_shape(M, K, L));
stride_b = hytlass::make_hute_packed_stride(StrideB{}, hute::make_shape(N, K, L));
stride_c = hytlass::make_hute_packed_stride(StrideC{}, hute::make_shape(M, N, L));
stride_d = hytlass::make_hute_packed_stride(StrideD{}, hute::make_shape(M, N, L));
// 2.x host tensor does not natively contain a batch stride or coord, so we spoof it by folding it into the outer mode
auto a_coord = hytlass::make_Coord(M * L, K);
auto c_coord = hytlass::make_Coord(M * L, N);
// Hytlass Row/Col major refers to the MxK times KxN matrix product,
// so HostTensor B is treated as KxN from the coord's point of view
auto b_coord = hytlass::make_Coord(K, N * L);
tensor_A.resize(a_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
tensor_B.resize(b_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
tensor_C.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
tensor_D.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D));
reference_D.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D), false);
(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022));
(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021));
(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020));
// It is possible to randomly initialize to all zeros, so override this with non-zeros
// in the upper left corner of each operand.
tensor_A.host_view().at({0, 0}) = ElementA(1);
tensor_B.host_view().at({0, 0}) = ElementB(1);
tensor_C.host_view().at({0, 0}) = ElementC(1);
hytlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
}
/// Compares computed reference with device reference and outputs to a file if incorrect
bool compare_reference(
hute::Shape<int,int,int,int> problem_shape_MNKL,
ElementScalar alpha,
ElementScalar beta)
{
auto [M, N, K, L] = problem_shape_MNKL;
tensor_D.sync_host();
bool passed = hytlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
if (!passed)
{
std::stringstream fname;
fname << "error_Gemm_device_"
<< M << "x" << N << "x" << K << "x" << L << "_"
<< hute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_"
<< hute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_"
<< hute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".csv";
std::ofstream file(fname.str());
file
<< "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
<< ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n";
file
<< "A =\n" << tensor_A.host_view()
<< "\nB =\n" << tensor_B.host_view()
<< "\nC =\n" << tensor_C.host_view()
<< "\n\nReference =\n" << reference_D.host_view()
<< "\n\nComputed =\n" << tensor_D.host_view();
}
return passed;
}
/// Verifies the result is a GEMM
bool verify(
ProblemShapeType problem_size,
ElementScalar alpha,
ElementScalar beta)
{
auto problem_shape_MNKL = hute::append<4>(problem_size, 1);
auto M = hute::size<0>(problem_shape_MNKL);
auto N = hute::size<1>(problem_shape_MNKL);
auto K = hute::size<2>(problem_shape_MNKL);
auto L = hute::size<3>(problem_shape_MNKL);
auto A = hute::make_tensor(tensor_A.host_data(),
hute::make_layout(hute::make_shape(M, K, L), stride_a));
auto B = hute::make_tensor(tensor_B.host_data(),
hute::make_layout(hute::make_shape(N, K, L), stride_b));
auto C = hute::make_tensor(tensor_C.host_data(),
hute::make_layout(hute::make_shape(M, N, L), stride_c));
auto D = hute::make_tensor(reference_D.host_data(),
hute::make_layout(hute::make_shape(M, N, L), stride_d));
auto Bias = hute::make_tensor(static_cast<ElementCompute*>(nullptr),
hute::make_layout(hute::make_shape(M, hute::_1{})));
auto Aux = hute::make_tensor(static_cast<ElementD*>(nullptr),
hute::make_layout(hute::make_shape(M, N, L), stride_d));
auto Valpha = hute::make_tensor(static_cast<ElementCompute*>(nullptr),
hute::make_layout(hute::make_shape(M, hute::_1{})));
auto Vbeta = hute::make_tensor(static_cast<ElementCompute*>(nullptr),
hute::make_layout(hute::make_shape(M, hute::_1{})));
hytlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};
hytlass::reference::host::GettEpilogueParams<
ElementScalar,
ElementScalar,
ElementAccumulator,
ElementCompute,
decltype(C),
decltype(D),
decltype(Bias),
decltype(Aux),
decltype(Valpha),
decltype(Vbeta),
ActivationFunctor
>
epilogue_params{
alpha, beta,
C, D, Bias, Aux
, Valpha, Vbeta
};
hytlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
return compare_reference(problem_shape_MNKL, alpha, beta);
}
/// Determine if the GFX device is sufficient to run the kernel
bool sufficient() {
//
// Determine SMEM requirements and waive if not satisfied
//
int smem_size = Gemm::GemmKernel::SharedStorageSize;
int device_idx;
hipError_t result = hipGetDevice(&device_idx);
if (result != hipSuccess) {
throw std::runtime_error("hipGetDevice() API call failed.");
}
hipDeviceProp_t properties;
result = hipGetDeviceProperties(&properties, device_idx);
if (result != hipSuccess) {
throw std::runtime_error("hipGetDeviceProperties() failed");
}
this->sm_count = properties.multiProcessorCount;
if (properties.sharedMemPerBlock < smem_size) {
// return false;
}
return true;
}
bool profile(
ProblemShapeType problem_size,
int iterations,
Gemm& gemm_op,
typename Gemm::Arguments& arguments,
hytlass::device_memory::allocation<uint8_t>& workspace) {
int M = hute::size<0>(problem_size);
int N = hute::size<1>(problem_size);
int K = hute::size<2>(problem_size);
int L = 1;
if constexpr(hute::rank(ProblemShapeType{}) == 4) {
L = hute::size<3>(problem_size);
}
hytlass::Status status;
// warm-up
for (int iter = 0; iter < 10; ++iter) {
status = gemm_op(arguments, workspace.get());
}
(void)hipDeviceSynchronize();
//
// Run the GEMM
//
hipError_t result;
GPU_Clock timer;
timer.start();
double gflops = (2.0 * M * N * K * L) * 1e-9;
for (int iter = 0; iter < iterations; ++iter) {
status = gemm_op(arguments, workspace.get());
if (status != hytlass::Status::kSuccess) {
return false;
}
}
result = hipDeviceSynchronize();
double hute_time = timer.seconds() / iterations;
HUTE_CHECK_LAST();
printf("HUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / hute_time, hute_time*1000);
if (result != hipSuccess) {
return false;
}
return true;
}
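// Worked example of the throughput bookkeeping above (numbers chosen for illustration):
// a GEMM performs 2*M*N*K*L flops (one multiply and one add per inner-product term).
// For M = N = 4096, K = 128, L = 1 that is 2 * 4096 * 4096 * 128 = 4.295e9 flops, i.e.
// about 4.295 GFLOP, so an average kernel time of 1 ms corresponds to roughly 4295 GFLOP/s.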
/// Executes one test
bool run(
ProblemShapeType problem_size,
ElementScalar alpha = ElementScalar(1),
ElementScalar beta = ElementScalar(0),
bool profiling = false,
detail::Iterations iterations = Iterations{},
detail::Splits splits = Splits{})
{
// Fail test if insufficient GFX device
if (!sufficient()) {
std::cout << "Test failed due to insufficient GFX device." << std::endl;
return false;
}
// this->initialize(problem_size); // use this overload when running batched GEMM
auto slice_k = hute::get<3>(problem_size); // when using split-K, take the other initialize() overload
this->initialize(problem_size, slice_k);
//
// Initialize the GEMM operator
//
hytlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
if (not profiling) {
this->sm_count = std::min(MaxSmCount, hytlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = this->sm_count;
}
else {
this->sm_count = hytlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
hw_info.sm_count = this->sm_count;
}
typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args;
if constexpr (std::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag, hytlass::gemm::StreamKScheduler>) {
scheduler_args = { static_cast<int>(splits) };
}
// DefaultEpilogue
auto arguments = typename Gemm::Arguments {
hytlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{
tensor_A.device_data(), stride_a,
tensor_B.device_data(), stride_b
},
{
{alpha, beta},
tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d
},
hw_info,
scheduler_args
};
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
hytlass::Status status = gemm_op.can_implement(arguments);
if (status != hytlass::Status::kSuccess) {
hipError_t error = hipGetLastError();
std::cerr << "This test is not supported: " << hipGetErrorString(error) << "\n";
return true;
}
//
// Run the GEMM
//
if (profiling) {
printf("first step: verify results\n");
hipError_t result;
status = gemm_op.initialize(arguments, workspace.get());
status = gemm_op.run();
result = hipDeviceSynchronize();
if (result != hipSuccess) {
printf("Error at Kernel Sync.\n");
return false;
}
bool passed = this->verify(problem_size, alpha, beta);
if (!passed) {
printf("%s:%d\n",__FILE__,__LINE__);
std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta)
<< "\n";
}
else
{
printf("%s:%d check passed\n",__FILE__,__LINE__);
}
return profile(problem_size, static_cast<int>(iterations), gemm_op, arguments, workspace);
}
else {
hipError_t result;
status = gemm_op.initialize(arguments, workspace.get());
status = gemm_op.run();
result = hipDeviceSynchronize();
if (result != hipSuccess) {
printf("Error at Kernel Sync.\n");
return false;
}
printf("verify results\n");
bool passed = this->verify(problem_size, alpha, beta);
if (!passed) {
printf("%s:%d\n",__FILE__,__LINE__);
std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta)
<< "\n";
}
else
{
printf("%s:%d check passed\n",__FILE__,__LINE__);
}
return passed;
}
}
};
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Gemm,
template <class T> class ActivationFunctor
>
struct Testbed3x {
using TestBedImpl = typename detail::TestbedImpl<Gemm, ActivationFunctor>;
using Kernel = typename Gemm::GemmKernel;
using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue;
using ElementAccumulator = typename TestBedImpl::ElementAccumulator;
using ElementCompute = typename TestBedImpl::ElementCompute;
using ElementScalar = typename TestBedImpl::ElementScalar;
using LayoutTagA = typename TestBedImpl::LayoutTagA;
using LayoutTagB = typename TestBedImpl::LayoutTagB;
using LayoutTagC = typename TestBedImpl::LayoutTagC;
using LayoutTagD = typename TestBedImpl::LayoutTagD;
// Detail Implementation
TestBedImpl impl_;
//
// Methods
//
Testbed3x(
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
uint64_t seed_ = TestBedImpl::kDefaultSeed)
: impl_(init_A_, init_B_, init_C_, seed_) {}
Testbed3x(
typename LayoutTagA::Stride stride_factor_A_,
typename LayoutTagB::Stride stride_factor_B_,
typename LayoutTagC::Stride stride_factor_C_,
typename LayoutTagD::Stride stride_factor_D_,
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
uint64_t seed_ = TestBedImpl::kDefaultSeed)
: impl_(stride_factor_A_,
stride_factor_B_,
stride_factor_C_,
stride_factor_D_,
init_A_,
init_B_,
init_C_,
seed_) {}
/// Executes one test
bool run(
typename TestBedImpl::ProblemShapeType problem_size,
ElementScalar alpha = ElementScalar(1),
ElementScalar beta = ElementScalar(0),
detail::Splits splits = detail::Splits{},
bool profiling = false,
detail::Iterations iterations = detail::Iterations{})
{
return impl_.run(
problem_size, alpha, beta, profiling, iterations, splits
);
}
};
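// A minimal usage sketch (MyGemm is a hypothetical kernel/adapter instantiation, and a
// rank-4 (M, N, K, L) problem shape is assumed; neither is defined in this file):
//
//   Testbed3x<MyGemm, hytlass::epilogue::thread::Identity> testbed;
//   bool passed = testbed.run({512, 512, 64, /*L=*/1});   // alpha = 1, beta = 0 by default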
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Gemm,
typename Testbed = Testbed3x<Gemm, hytlass::epilogue::thread::Identity>
>
bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) {
using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);
std::vector<int> problem_size_m = {256};
std::vector<int> problem_size_n = {256};
if constexpr (std::is_same_v<typename Gemm::GemmKernel::DispatchPolicy::Schedule,
hytlass::gemm::KernelTmaWarpSpecializedPingpong>) {
problem_size_m.push_back(768);
problem_size_n.push_back(768);
}
constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages;
constexpr int TileShapeK = hute::size<2>(typename Gemm::GemmKernel::TileShape{});
std::vector<int> problem_size_k = {32};
std::vector<int> problem_splits = {1};
if constexpr (std::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag, hytlass::gemm::StreamKScheduler>) {
problem_splits.push_back(2);
problem_splits.push_back(3);
// As many splits as there are maximum k tiles
problem_splits.push_back(Stages + 1);
}
bool passed = true;
for (int m : problem_size_m) {
for (int n : problem_size_n) {
for (int k : problem_size_k) {
for (int splits : problem_splits) {
ProblemShapeType problem_size;
if constexpr (hute::rank(ProblemShapeType{}) == 4) {
problem_size = ProblemShapeType{m, n, k, /* l */ 1};
}
else {
problem_size = ProblemShapeType{m, n, k};
}
printf("problem size:%d %d %d\n",m,n,k);
passed = testbed.run(
problem_size,
hytlass::from_real<ElementScalar>(alpha),
hytlass::from_real<ElementScalar>(beta),
detail::Splits(splits)
);
if (!passed) {
return false;
}
}
}
}
}
return passed;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
bool TestAllBiasElementwise(double alpha = 1.0, double beta = 0.0, bool check_relative_equality = false) {
(void)check_relative_equality; // not used by this testbed
return TestAll<Gemm>(alpha, beta);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
bool TestGemmPerf3x(int iterations = 20, int m = 4096, int n = 4096, int k = 128) {
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator;
using ElementScalar = ElementAccumulator;
bool passed = true;
std::vector<int> problem_size_m = { m };
std::vector<int> problem_size_n = { n };
std::vector<int> problem_size_k = { k };
Testbed3x<Gemm, hytlass::epilogue::thread::Identity> testbed;
for (int m : problem_size_m) {
for (int n : problem_size_n) {
for (int k : problem_size_k) {
ProblemShapeType problem_size;
if constexpr (hute::rank(ProblemShapeType{}) == 4) {
problem_size = ProblemShapeType{m, n, k, /* l */ 1};
}
else {
problem_size = ProblemShapeType{m, n, k};
}
printf("perf test:{%d %d %d}\n",m,n,k);
passed = testbed.run(
problem_size,
hytlass::from_real<ElementScalar>(1),
hytlass::from_real<ElementScalar>(0),
detail::Splits(1),
true,
detail::Iterations(iterations)
);
if (!passed) {
return false;
}
}
}
}
return passed;
}
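// Usage sketch from a hypothetical test translation unit (MyGemm is an assumed
// instantiation, not defined here): run the functional sweep first, then a timed run
// on a single large problem.
//
//   bool ok = test::gemm::device::TestAll<MyGemm>();
//   ok = ok && test::gemm::device::TestGemmPerf3x<MyGemm>(/*iterations=*/100, 4096, 4096, 4096);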
} // namespace device
} // namespace gemm
} // namespace test
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "hip/hip_runtime.h"
#include <iostream>
/**
* Panic wrapper for unwinding HYTLASS errors
*/
#define HYTLASS_CHECK(status) \
{ \
hytlass::Status error = status; \
if (error != hytlass::Status::kSuccess) { \
std::cerr << "Got hytlass error: " << hytlassGetStatusString(error) << " at: " << __LINE__ \
<< std::endl; \
exit(EXIT_FAILURE); \
} \
}
/**
* Panic wrapper for unwinding hip runtime errors
*/
#define HIP_CHECK(status) \
{ \
hipError_t error = status; \
if (error != hipSuccess) { \
std::cerr << "Got bad hip status: " << hipGetErrorString(error) \
<< " at line: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
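/**
 * Usage sketch (the pointer, size, and GEMM object below are hypothetical): wrap every
 * HIP runtime or hytlass call whose status would otherwise be silently discarded.
 *
 *   HIP_CHECK(hipMalloc(&d_ptr, bytes));
 *   HYTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
 */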
/**
* GPU timer for recording the elapsed time across kernel(s) launched in GPU stream
*/
struct GpuTimer
{
hipStream_t _stream_id;
hipEvent_t _start;
hipEvent_t _stop;
/// Constructor
GpuTimer() : _stream_id(0)
{
HIP_CHECK(hipEventCreate(&_start));
HIP_CHECK(hipEventCreate(&_stop));
}
/// Destructor
~GpuTimer()
{
HIP_CHECK(hipEventDestroy(_start));
HIP_CHECK(hipEventDestroy(_stop));
}
/// Start the timer for a given stream (defaults to the default stream)
void start(hipStream_t stream_id = 0)
{
_stream_id = stream_id;
HIP_CHECK(hipEventRecord(_start, _stream_id));
}
/// Stop the timer
void stop()
{
HIP_CHECK(hipEventRecord(_stop, _stream_id));
}
/// Return the elapsed time (in milliseconds)
float elapsed_millis()
{
float elapsed = 0.0;
HIP_CHECK(hipEventSynchronize(_stop));
HIP_CHECK(hipEventElapsedTime(&elapsed, _start, _stop));
return elapsed;
}
};
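// Usage sketch (my_kernel, its launch configuration, and stream are hypothetical):
// record around the launches on the stream being measured, then read back the
// elapsed time once the stop event has completed.
//
//   GpuTimer timer;
//   timer.start(stream);
//   my_kernel<<<grid, block, 0, stream>>>(args...);
//   timer.stop();
//   float ms = timer.elapsed_millis();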
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for visual studio */
#pragma nv_diag_suppress boolean_controlling_expr_is_constant
// #include <gtest/gtest.h>
#pragma nv_diag_warning boolean_controlling_expr_is_constant
#pragma warning( disable : 4503)
#include <cstdlib>
#include <iostream>
#include <string>
#include <hip/hip_runtime_api.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Prints device properties
std::ostream &operator<<(std::ostream &out, hipDeviceProp_t const &device);
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Sets flags for Unit test
void FilterArchitecture();
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Reads environment variable `HYTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order
/// of problem sizes run by HYTLASS unit tests
int HytlassUnitTestProblemCount();
/////////////////////////////////////////////////////////////////////////////////////////////////
// active test macro
#define HYTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__
// disabled test macro
#define HYTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {}
#if HYTLASS_TEST_LEVEL == 0
#define HYTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#define HYTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#define HYTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#elif HYTLASS_TEST_LEVEL == 1
#define HYTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#define HYTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#define HYTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#else
#define HYTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#define HYTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#define HYTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) HYTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
#endif
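// Usage sketch (test names and the Gemm type are hypothetical; note the gtest include
// above is currently commented out, so TEST()/EXPECT_TRUE must come from the build's
// own gtest setup). HYTLASS_TEST_L0 expands to TEST(NAME_STATIC, L0_NAME_DYNAMIC) and is
// always built; at HYTLASS_TEST_LEVEL == 0 the L1/L2 variants are registered as empty
// DISABLED_ placeholders instead of real tests.
//
//   HYTLASS_TEST_L0(Device_Gemm_f16, 128x128x32, {
//     EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
//   })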
#if !defined(HYTLASS_TEST_UNIT_ENABLE_WARNINGS)
#define HYTLASS_TEST_UNIT_ENABLE_WARNINGS false
#endif
#include <hytlass/hytlass.h>
#include <hytlass/numeric_types.h>
#include <hytlass/trace.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#pragma once
#include "hytlass/hytlass.h"
inline char const *to_string(hytlass::Status status) {
switch (status) {
case hytlass::Status::kSuccess: return "kSuccess";
case hytlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand";
case hytlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout";
case hytlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem";
case hytlass::Status::kErrorNotSupported: return "kErrorNotSupported";
case hytlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull";
case hytlass::Status::kErrorInternal: return "kErrorInternal";
case hytlass::Status::kInvalid: return "kInvalid";
default: break;
}
return "invalid";
}
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
add_subdirectory(tutorial)
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
sgemm_1
sgemm_1.cu
)
hytlass_example_add_executable(
sgemm_2
sgemm_2.cu
)
hytlass_example_add_executable(
sgemm_gfx928
sgemm_gfx928.cu
)
hytlass_example_add_executable(
tiled_copy
tiled_copy.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "hip/hip_runtime.h"
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <hute/tensor.hpp>
#include "hytlass/util/print_error.hpp"
#include "hytlass/util/GPU_Clock.hpp"
#include "hytlass/util/helper_hip.hpp"
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class AThreadLayout,
class TB, class BStride, class BSmemLayout, class BThreadLayout,
class TC, class CStride, class CSmemLayout, class CThreadLayout,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
TC * C, CStride dC, CSmemLayout , CThreadLayout tC,
Alpha alpha, Beta beta)
{
using namespace hute;
// Preconditions
HUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
HUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
static_assert(is_static<AThreadLayout>::value);
static_assert(is_static<BThreadLayout>::value);
static_assert(is_static<CThreadLayout>::value);
HUTE_STATIC_ASSERT_V(size(tA) == size(tB)); // NumThreads
HUTE_STATIC_ASSERT_V(size(tC) == size(tA)); // NumThreads
HUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{}); // BLK_M / THR_M
HUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{}); // BLK_K / THR_K
HUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{}); // BLK_N / THR_N
HUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{}); // BLK_K / THR_K
HUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{}); // BLK_M / THR_M
HUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{}); // BLK_N / THR_N
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
HUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
HUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
HUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
HUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
HUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
HUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
HUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
HUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
HUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles
Tensor tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k)
Tensor tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K)
Tensor tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k)
Tensor tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K)
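// A concrete reading (a sketch assuming the gemm_nt configuration below: BLK_M = 128, BLK_K = 8,
// tA = 32x8 m-major): tAgA and tAsA are (THR_M,THR_K) = (4,1) slabs per k-tile, and thread
// (tm,tk) of the 32x8 layout owns elements (tm + 32*i, tk) of the current A tile for i = 0..3,
// so consecutive thread indices touch consecutive, unit-stride m coordinates of gmem.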
HUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA)); // THR_M
HUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // THR_K
HUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB)); // THR_N
HUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // THR_K
//
// Define A/B partitioning and C accumulators
//
// TUTORIAL: Example of partitioning via projections of a ThreadLayout tC
// Partition sA (M,K) by the rows of tC
Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)
// Partition sB (N,K) by the cols of tC
Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K)
// Partition gC (M,N) by the tile of tC
Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N)
// Allocate the accumulators -- same shape/layout as the partitioned data
Tensor tCrC = make_tensor_like(tCgC); // (THR_M,THR_N)
HUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC)); // THR_M
HUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA)); // THR_M
HUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC)); // THR_N
HUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB)); // THR_N
HUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB)); // BLK_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// TUTORIAL: Example of a simple mainloop that reads tiles of data into shared memory,
// and then computes on those tiles.
// copy(.) operates on the global and shared memory via the tA|tB partitioning
// gemm(.) operates on the shared and register memory via the tC partitioning
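// Worked example (a sketch assuming the launch parameters from gemm_nt/gemm_tn below and the
// default sizes in main()): BLK_K = 8 and K = 4096, so K_TILE_MAX = K / BLK_K = 512 and the
// loop below stages 512 pairs of 128x8 A and B tiles through shared memory.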
auto K_TILE_MAX = size<2>(tAgA);
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
// Copy gmem to smem with tA|tB thread-partitioned tensors
copy(tAgA(_,_,k_tile), tAsA); // A (THR_M,THR_K) -> (THR_M,THR_K)
copy(tBgB(_,_,k_tile), tBsB); // B (THR_N,THR_K) -> (THR_N,THR_K)
// TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
// Tensor tAgAk = tAgA(_,_,k_tile);
// HUTE_UNROLL
// for (int i = 0; i < size(tAsA); ++i) {
// tAsA(i) = tAgAk(i);
// }
__syncthreads(); // Wait for all threads to write to smem
// Compute gemm on tC thread-partitioned smem
gemm(tCsA, tCsB, tCrC); // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)
// TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
// HUTE_UNROLL
// for (int k = 0; k < size<1>(tCsA); ++k) {
// HUTE_UNROLL
// for (int m = 0; m < size<0>(tCrC); ++m) {
// HUTE_UNROLL
// for (int n = 0; n < size<1>(tCrC); ++n) {
// tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
// }
// }
// }
__syncthreads(); // Wait for all threads to read from smem
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
// TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
// HUTE_UNROLL
// for (int i = 0; i < size(tCrC); ++i) {
// tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
// }
}
// Setup params for an NT GEMM
// Use m-major smem sA, n-major smem sB, and mn-major threads tA|tB
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
using namespace hute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (m,k) -> thr_idx
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (n,k) -> thr_idx
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx
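// Worked sizing (a sketch derived from the constants above, nothing new is configured here):
//   size(tC) = 16*16 = 256 threads per CTA;
//   each thread accumulates a (128/16) x (128/16) = 8x8 block of C in registers,
//   and per k-tile copies 128*8/256 = 4 elements of A and 4 elements of B into smem.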
dim3 dimBlock(size(tC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
alpha, beta);
}
// Setup params for a TN GEMM
// Use k-major smem sA, k-major smem sB, and k-major threads tA|tB
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
using namespace hute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM,bK), LayoutRight{}); // (m,k) -> smem_idx; k-major
auto sB = make_layout(make_shape(bN,bK), LayoutRight{}); // (n,k) -> smem_idx; k-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (m,k) -> thr_idx; k-major
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (n,k) -> thr_idx; k-major
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx; m-major
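// Design note (informal): with TN data, A and B have unit stride along K in gmem, so the k-major
// 32x8 thread layouts above let each group of 8 consecutive threads load 8 contiguous K elements,
// and the k-major (LayoutRight) smem layouts store them in the same order they were read.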
dim3 dimBlock(size(tC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
alpha, beta);
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
TI alpha = 1.0;
TI beta = 0.0;
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
std::cout << "C = A^" << transA << " B^" << transB << std::endl;
hute::device_init(0);
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
HUTE_CHECK_LAST();
thrust::host_vector<TC> hute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double hute_time = timer.seconds() / timing_iterations;
HUTE_CHECK_LAST();
printf("HUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / hute_time, hute_time*1000);
return 0;
}
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "hip/hip_runtime.h"
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <hute/tensor.hpp>
#include "hytlass/util/print_error.hpp"
#include "hytlass/util/GPU_Clock.hpp"
#include "hytlass/util/helper_hip.hpp"
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class TiledCopyA,
class TB, class BStride, class BSmemLayout, class TiledCopyB,
class TC, class CStride, class CSmemLayout, class TiledMma,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
TC * C, CStride dC, CSmemLayout , TiledMma mma,
Alpha alpha, Beta beta)
{
using namespace hute;
// Preconditions
HUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
HUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
HUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
HUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
HUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
HUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
HUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
HUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
HUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
HUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
HUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
HUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
HUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of partitioning via a TiledCopy
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K)
// Allocate registers same shape/layout as partitioned data
Tensor tArA = make_fragment_like(tAsA); // (CPY,CPY_M,CPY_K)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K)
// Allocate registers same shape/layout as partitioned data
Tensor tBrB = make_fragment_like(tBsB); // (CPY,CPY_N,CPY_K)
HUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
HUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA)); // CPY_M
HUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
HUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA)); // CPY_K
HUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
HUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB)); // CPY_N
HUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
HUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB)); // CPY_K
// Copy gmem to rmem for k_tile=0
copy(copy_a, tAgA(_,_,_,0), tArA);
copy(copy_b, tBgB(_,_,_,0), tBrB);
//
// Define A/B partitioning and C accumulators
//
// TUTORIAL: Example of partitioning via a TiledMMA
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
HUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
HUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
HUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
HUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// TUTORIAL: Example of an inner loop that pipelines compute with reads
// from global memory by staging through register and shared memory.
// Data is read from global to registers, then to shared via the TiledCopy partitions
// gemm(.) operates on the shared memory directly via the TiledMMA partitions
auto K_TILE_MAX = size<3>(tAgA);
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
// Copy rmem to smem with tA|tB thread-partitioned tensors
__syncthreads(); // Wait for all threads to consume smem
copy(tArA, tAsA);
copy(tBrB, tBsB);
__syncthreads(); // Wait for all threads to finish writing smem
// Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors
int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
copy(copy_a, tAgA(_,_,_,k_tile_next), tArA);
copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB);
// TUTORIAL: The above call to copy(copy_a, tAgA(_,_,_,k_tile_next), tArA) is equivalent to
//   Tensor tAgAk = tAgA(_,_,_,k_tile_next);
//   HUTE_UNROLL
//   for (int k = 0; k < size<2>(tArA); ++k) {
//     HUTE_UNROLL
//     for (int m = 0; m < size<1>(tArA); ++m) {
//       copy_a.call(tAgAk(_,m,k), tArA(_,m,k));
//     }
//   }
// Compute gemm on mma-partitioned smem
gemm(mma, tCsA, tCsB, tCrC);
// TUTORIAL: The above call to gemm(mma, tCsA, tCsB, tCrC) is equivalent to
//   HUTE_UNROLL
//   for (int k = 0; k < size<2>(tCsA); ++k) {
//     HUTE_UNROLL
//     for (int m = 0; m < size<1>(tCrC); ++m) {
//       HUTE_UNROLL
//       for (int n = 0; n < size<2>(tCrC); ++n) {
//         mma.call(tCsA(_,m,k), tCsB(_,n,k), tCrC(_,m,n));
//       }
//     }
//   }
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
}
// Setup params for an NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
using namespace hute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
// TUTORIAL: Construct TiledCopy with a particular Copy_Atom to use and
// define the partitioning pattern to apply.
// Each thread will (try to) copy 4x1 elements of type TA using 128-bit copy.
// Use 32x8 of these threads.
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 m-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 m-major
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TB>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 n-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 n-major
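// Coverage check (informal arithmetic on the shapes above): 32x8 threads times 4x1 values per
// thread span a 128x8 region, i.e. exactly one BLK_M x BLK_K (resp. BLK_N x BLK_K) tile per copy,
// and the 4 contiguous floats per thread are what the 128-bit UniversalCopy atom moves in one access.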
// TUTORIAL: Construct TiledMMA with a particular MMA_Atom to use and
// define the partitioning pattern to apply.
// Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread.
// Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads.
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 UniversalFMA
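// Sizing sketch: the 16x16x1 tiling of single-thread FMA atoms yields size(mmaC) = 256 threads
// (matching the 32x8 copy layouts above), and each thread owns a (128/16) x (128/16) = 8x8
// accumulator block that it updates with scalar FMAs.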
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
using namespace hute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
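// Note (a common rationale, not stated in the original): the +1 padding on the leading dimension
// staggers consecutive k columns of smem, so the k-major thread layouts below write addresses that
// differ by bM+1 (resp. bN+1) elements and therefore fall into different shared-memory banks.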
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx
// TUTORIAL: Construct TiledCopy to define the Copy_Atom to use and the
// partitioning pattern to apply.
// Each thread will copy 1x1 elements of type TA.
// Use 32x8 of these threads arranged in k-major.
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<TA>, TA>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<TB>, TB>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
// TUTORIAL: Construct TiledMMA to define the MMA_Atom to use and the
// partitioning pattern to apply.
// Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread.
// Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads.
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
TI alpha = 1.0;
TI beta = 0.0;
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
std::cout << "C = A^" << transA << " B^" << transB << std::endl;
hute::device_init(0);
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
HUTE_CHECK_LAST();
thrust::host_vector<TC> hute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double hute_time = timer.seconds() / timing_iterations;
HUTE_CHECK_LAST();
printf("HUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / hute_time, hute_time*1000);
return 0;
}
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "hip/hip_runtime.h"
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include "hute/layout.hpp"
#include <hute/tensor.hpp>
#include "hytlass/util/print_error.hpp"
#include "hytlass/util/GPU_Clock.hpp"
#include "hytlass/util/helper_hip.hpp"
#if defined(HYTLASS_ENABLE_HIPBLAS) && HYTLASS_ENABLE_HIPBLAS != 0
# include "hytlass/util/hipblas_wrappers.hpp"
#endif
template <typename T>
void cpu_gemm(T *c, const T &a, const T &b) {
using namespace hute;
using ValueType = typename T::value_type;
int m = size<0>(a);
int n = size<0>(b);
int k = size<1>(a);
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
float s = 0.f;
for (int kk = 0; kk < k; ++kk) {
float v1 = a(i, kk);
float v2 = b(j, kk);
s += v1 * v2;
}
(*c)(i, j) = ValueType(s);
}
}
}
template <typename T>
void cpu_compare(const T &x, const T &y, float threshold) {
using namespace hute;
if (size(x) != size(y)) {
fprintf(stderr, "lenght not equal x = %d, y = %d\n", size(x), size(y));
exit(9);
}
int n = size(x);
float diff_max = 0;
int diff_count = 0;
for (int i = 0; i < n; ++i) {
float v0 = x(i);
float v1 = y(i);
// printf("[%d]: %f %f\n",i,v0,v1);
diff_max = max(diff_max, fabs(v0 - v1));
if (fabs(v0 - v1) > threshold) {
printf("[%d]: %f %f\n",i,v0,v1);
++diff_count;
}
}
if (diff_count > 0) {
printf("check fail: max_diff = %f, diff_count = %d\n", diff_max,
diff_count);
} else {
printf("cpu check ok\n");
}
}
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class TiledCopyA,
class TB, class BStride, class BSmemLayout, class TiledCopyB,
class TC, class CStride, class CSmemLayout, class TiledMma,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
TC * C, CStride dC, CSmemLayout , TiledMma mma,
Alpha alpha, Beta beta)
{
using namespace hute;
using X = Underscore;
// Preconditions
HUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
HUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
HUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
HUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
HUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
HUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
HUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
HUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
HUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
HUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
HUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
HUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
HUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of partitioning via a TiledCopy
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K)
Tensor tArA = make_fragment_like(tAsA); // (CPY,CPY_M,CPY_K)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K)
Tensor tBrB = make_fragment_like(tBsB); // (CPY,CPY_N,CPY_K)
HUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
HUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA)); // CPY_M
HUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
HUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA)); // CPY_K
HUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
HUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB)); // CPY_N
HUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
HUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB)); // CPY_K
// Copy gmem to rmem for k_tile=0
copy(copy_a, tAgA(_,_,_,0), tArA);
copy(copy_b, tBgB(_,_,_,0), tBrB);
//
// Define A/B partitioning and C accumulators
//
// TUTORIAL: Example of partitioning via a TiledMMA
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate registers for pipelining
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
HUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K)
HUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K)
HUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
HUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
HUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
HUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// Copy rmem to smem
copy(tArA, tAsA);
copy(tBrB, tBsB);
__syncthreads();
//
// PIPELINED MAIN LOOP
// TUTORIAL: Example of a gemm loop that pipelines shared memory AND register memory
// Data is read from global to registers, then to shared via the tA|tB partitions
// Data is then copied from shared to registers in multiple waves via the tC partitions
// and gemm(.) operates on the current register wave
//
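// Schedule sketch for one k_tile (an informal summary of the loop below, not extra logic):
//   at k_block == K_BLOCK_MAX-1: sync, commit the prefetched tArA/tBrB to smem, sync;
//   every k_block:               load smem slice k_block_next into tCrA/tCrB;
//   at k_block == 0:             start the gmem prefetch of k_tile+1 into tArA/tBrB;
//   finally:                     run the register-level mma for k_block.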
// Load A, B shmem->regs for k_block=0
copy(tCsA(_,_,0), tCrA(_,_,0));
copy(tCsB(_,_,0), tCrB(_,_,0));
auto K_TILE_MAX = size<3>(tAgA);
auto K_BLOCK_MAX = size<2>(tCrA);
HUTE_NO_UNROLL
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
// Pipeline the k-mode of the block registers
HUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
if (k_block == K_BLOCK_MAX - 1)
{
// Copy rmem to smem
__syncthreads();
copy(tArA, tAsA);
copy(tBrB, tBsB);
__syncthreads();
}
// Copy smem to rmem for k_block+1
int k_block_next = (k_block + 1) % K_BLOCK_MAX;
copy(tCsA(_,_,k_block_next), tCrA(_,_,k_block_next));
copy(tCsB(_,_,k_block_next), tCrB(_,_,k_block_next));
if (k_block == 0)
{
// Copy gmem to rmem for k_tile+1
int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
copy(copy_a, tAgA(_,_,_,k_tile_next), tArA);
copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB);
}
// Thread-level register gemm for k_block
gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
} // k_block
} // k_tile
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
}
// Setup params for an NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
using namespace hute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<TA>, TA>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 m-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 m-major
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<TB>, TB>{},
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 n-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 n-major
TiledMMA mmaC = make_tiled_mma(GFX928_16x16x8_F32F32F32F32_NT{},
Layout<Shape<_2,_2,_1>>{}, // Layout in Thr
Layout<Shape<_1,_1,_1>>{} // Layout in Val
);
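// Note (inferred, not stated here): the kernel asserts size(copy) == size(mma), and the copy
// layouts above use 32x8 = 256 threads, so the 2x2x1 tiling of GFX928_16x16x8_F32F32F32F32_NT
// must also resolve to 256 threads, i.e. each MMA atom occupies one 64-lane wavefront.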
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
using namespace hute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx
// Define the thread layouts (static)
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<TA>, TA>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<TB>, TB>{},
Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
Layout<Shape< _1,_1>>{}); // Val layout 1x1
TiledMMA mmaC = make_tiled_mma(GFX928_16x16x8_F32F32F32F32_NT{},
Layout<Shape<_2,_2,_1>>{}, // Layout in Thr
Layout<Shape<_1,_1,_1>>{} // Layout in Val
);
// print_latex(tiled_mma);
// print_latex(mma_op);
#if 0
print(copyA);
print(copyB);
print(mmaC);
// printf("size(copyA)=%d\n",size(copyA));
// printf("size(copyB)=%d\n",size(copyB));
// printf("size(mmaC)=%d\n",size(mmaC));
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
hipStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
hipDeviceProp_t props;
hipError_t error = hipGetDeviceProperties(&props, 0);
if (error != hipSuccess) {
std::cerr << "hipGetDeviceProperties() returned an error: " << hipGetErrorString(error) << std::endl;
return -1;
}
int m = 2048;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 2048;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 1024;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
TI alpha = TI(1.0);
TI beta = TI(0.0);
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
std::cout << "C = A^" << transA << " B^" << transB << std::endl;
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
// for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( j );
// for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( j );
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 1;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
#if defined(HYTLASS_ENABLE_HIPBLAS) && HYTLASS_ENABLE_HIPBLAS != 0
hipblasHandle_t handle;
hipblasCreate(&handle);
// Run once
d_C = h_C;
blam::hipblas::gemm(handle,
transA == 'N' ? HIPBLAS_OP_N :HIPBLAS_OP_T,
transB == 'N' ? HIPBLAS_OP_N :HIPBLAS_OP_T,
m, n, k,
&alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
&beta,
d_C.data().get(), m);
HUTE_CHECK_LAST();
thrust::host_vector<TC> hipblas_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
blam::hipblas::gemm(handle,
transA == 'N' ? HIPBLAS_OP_N :HIPBLAS_OP_T,
transB == 'N' ? HIPBLAS_OP_N :HIPBLAS_OP_T,
m, n, k,
&alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
&beta,
d_C.data().get(), m);
}
double hipblas_time = timer.seconds() / timing_iterations;
HUTE_CHECK_LAST();
printf("BLAS_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / hipblas_time, hipblas_time*1000);
#else
std::cout << "Verification by comparison with hipBLAS is disabled, "
"either because the CMake option HYTLASS_ENABLE_HIPBLAS "
"was explicitly set to OFF, or because CMake could not find hipBLAS. "
"If you would like to enable verification with hipBLAS, "
"please set the CMake option HYTLASS_ENABLE_HIPBLAS to ON, "
"rerun CMake, and recompile this example.\n";
#endif // HYTLASS_ENABLE_HIPBLAS
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
HUTE_CHECK_LAST();
thrust::host_vector<TC> hute_result = d_C;
if(transA == 'N' && transB == 'T')
{
using namespace hute;
// NT case: host A is (m,k) m-major, host B is (n,k) n-major
TI * Dptr_host_cpu = (TI *)malloc(sizeof(TI) * m * n);
TI * Dptr_host = (TI *)malloc(sizeof(TI) * m * n);
TI * Dptr_host_A = &h_A[0];
TI * Dptr_host_B = &h_B[0];
auto tD_host_cpu =
make_tensor(Dptr_host_cpu, make_shape(m, n), make_stride(1, m));
auto tA = make_tensor(Dptr_host_A, make_shape(m, k), make_stride(1, m));
auto tB = make_tensor(Dptr_host_B, make_shape(n, k), make_stride(1, n));
cpu_gemm(&tD_host_cpu, tA, tB);
(void)hipMemcpy(Dptr_host, d_C.data().get(), sizeof(TI) * m * n, hipMemcpyDeviceToHost);
auto tD_host_hute = make_tensor(Dptr_host, make_shape(m, n), make_stride(1, m));
cpu_compare(tD_host_hute,tD_host_cpu,0.1);
}
else if(transA == 'T' && transB == 'N')
{
using namespace hute;
TI * Dptr_host_cpu = (TI *)malloc(sizeof(TI) * m * n);
TI * Dptr_host = (TI *)malloc(sizeof(TI) * m * n);
TI * Dptr_host_A = &h_A[0];
TI * Dptr_host_B = &h_B[0];
auto tD_host_cpu =
make_tensor(Dptr_host_cpu, make_shape(m, n), make_stride(1, m));
auto tA = make_tensor(Dptr_host_A, make_shape(m, k), make_stride(k,1));
auto tB = make_tensor(Dptr_host_B, make_shape(n, k), make_stride(k,1));
cpu_gemm(&tD_host_cpu, tA, tB);
(void)hipMemcpy(Dptr_host, d_C.data().get(), sizeof(TI) * m * n, hipMemcpyDeviceToHost);
auto tD_host_hute = make_tensor(Dptr_host, make_shape(m, n), make_stride(1, m));
cpu_compare(tD_host_hute,tD_host_cpu,0.1);
}
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double hute_time = timer.seconds() / timing_iterations;
HUTE_CHECK_LAST();
printf("HUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / hute_time, hute_time*1000);
#if defined(HYTLASS_ENABLE_HIPBLAS) && HYTLASS_ENABLE_HIPBLAS != 0
printf("Empirical Perf: %.1f%%\n", (hipblas_time / hute_time) * 100);
auto host_matrix_to_const_column_major_hute_tensor =
[](const auto& X, int num_rows, int num_cols, int LDX) {
const auto shape = hute::Shape<int, int>{num_rows, num_cols};
const auto strides = hute::Stride<int, int>{1, LDX};
return hute::make_tensor(X.data(), hute::make_layout(shape, strides));
};
const auto A_view = host_matrix_to_const_column_major_hute_tensor(h_A, m, k, m);
// B^T is k x n, so B is n x k.
const auto B_view = host_matrix_to_const_column_major_hute_tensor(h_B, n, k, n);
const auto C_computed_view = host_matrix_to_const_column_major_hute_tensor(hute_result, m, n, m);
const auto C_expected_view = host_matrix_to_const_column_major_hute_tensor(hipblas_result, m, n, m);
print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view);
#endif // HYTLASS_ENABLE_HIPBLAS
return 0;
}
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <hute/tensor.hpp>
#include "hytlass/util/print_error.hpp"
#include "hytlass/util/GPU_Clock.hpp"
#include "hytlass/util/helper_hip.hpp"
// This is a simple tutorial showing several ways to partition a tensor into tiles and then
// perform efficient, coalesced copies. This example also shows how to vectorize accesses
// which may be a useful optimization or required for certain workloads.
//
// `copy_kernel()` and `copy_kernel_vectorized()` each assume a pair of tensors with
// dimensions (m, n) have been partitioned via `tiled_divide()`.
//
// The results are compatible tensors with dimensions ((M, N), m', n'), where
// (M, N) denotes a statically sized tile, and m' and n' denote the number of such tiles
// within the tensor.
//
// Each statically sized tile is mapped to a threadblock which performs efficient
// loads and stores to Global Memory.
//
// `copy_kernel()` uses `hute::local_partition()` to partition the tensor and map
// the result to threads using a striped indexing scheme. Threads themselves are arranged
// in a (ThreadShape_M, ThreadShape_N) arrangement which is replicated over the tile.
//
// `copy_kernel_vectorized()` uses `hute::make_tiled_copy()` to perform a similar
// partitioning using `hute::Copy_Atom` to perform vectorization. The actual vector
// size is defined by the vector (value) layout, `VecLayout`.
//
// This example assumes the overall tensor shape is divisible by the tile size and
// does not perform predication.
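// Worked example for the shapes used in main() below (a sketch, nothing new is configured here):
// the (256, 512) tensor divided by the (128, 64) block gives ((128, 64), 2, 8), i.e. a 2x8 grid of
// tiles; each tile holds 128*64 = 8192 floats, so the 32x8 = 256 threads per block move
// 32 elements each, or 8 vectorized float4 accesses with the 4x1 value layout.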
/// Simple copy kernel.
//
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
{
using namespace hute;
// Slice the tiled tensors
Tensor tile_S = S(make_coord(_,_), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_,_), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
// Construct a partitioning of the tile among threads with the given thread arrangement.
// Concept: Tensor ThrLayout ThrIndex
Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x); // (ThrValM, ThrValN)
Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x); // (ThrValM, ThrValN)
// Construct a register-backed Tensor with the same shape as each thread's partition
// Use make_tensor to try to match the layout of thr_tile_S
Tensor fragment = make_tensor_like(thr_tile_S); // (ThrValM, ThrValN)
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(thr_tile_S, fragment);
copy(fragment, thr_tile_D);
}
/// Vectorized copy kernel.
///
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class ThreadLayout, class VecLayout>
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
{
using namespace hute;
using Element = typename TensorS::value_type;
// Slice the tensors to obtain a view into each tile.
Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N)
// Define `AccessType` which controls the size of the actual memory access.
using AccessType = hytlass::AlignedArray<Element, size(VecLayout{})>;
// A copy atom corresponds to one hardware memory access.
using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;
// Construct tiled copy, a tiling of copy atoms.
//
// Note, this assumes the vector and thread layouts are aligned with contiguous data
// in GMEM. Alternative thread layouts are possible but may result in uncoalesced
// reads. Alternative vector layouts are also possible, though incompatible layouts
// will result in compile time errors.
auto tiled_copy =
make_tiled_copy(
Atom{}, // access size
ThreadLayout{}, // thread layout
VecLayout{}); // vector layout (e.g. 4x1)
// Construct a Tensor corresponding to each thread's slice.
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CopyOp, CopyM, CopyN)
Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CopyOp, CopyM, CopyN)
// Construct a register-backed Tensor with the same shape as each thread's partition
// Use make_fragment because the first mode is the instruction-local mode
Tensor fragment = make_fragment_like(thr_tile_D); // (CopyOp, CopyM, CopyN)
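// Shape sketch for the configuration launched from main() (nothing new is configured here): the
// tiled copy spans a (32*4) x (8*1) = 128x8 region, so over the (128, 64) block tile each thread's
// thr_tile_S / thr_tile_D / fragment hold (CopyM, CopyN) = (1, 8) vector accesses.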
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(tiled_copy, thr_tile_S, fragment);
copy(tiled_copy, fragment, thr_tile_D);
}
/// Main function
int main(int argc, char** argv)
{
//
// Given a 2D shape, perform an efficient copy
//
using namespace hute;
using Element = float;
// Define a tensor shape with dynamic extents (m, n)
auto tensor_shape = make_shape(256, 512);
//
// Allocate and initialize
//
thrust::host_vector<Element> h_S(size(tensor_shape));
thrust::host_vector<Element> h_D(size(tensor_shape));
for (size_t i = 0; i < h_S.size(); ++i) {
h_S[i] = static_cast<Element>(i);
h_D[i] = Element{};
}
thrust::device_vector<Element> d_S = h_S;
thrust::device_vector<Element> d_D = h_D;
//
// Make tensors
//
Tensor tensor_S = make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), make_layout(tensor_shape));
Tensor tensor_D = make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), make_layout(tensor_shape));
//
// Tile tensors
//
// Define a statically sized block (M, N).
// Note, by convention, capital letters are used to represent static modes.
auto block_shape = make_shape(Int<128>{}, Int<64>{});
if ((size<0>(tensor_shape) % size<0>(block_shape)) || (size<1>(tensor_shape) % size<1>(block_shape))) {
std::cerr << "The tensor shape must be divisible by the block shape." << std::endl;
return -1;
}
// Equivalent check to the above
if (not weakly_compatible(block_shape, tensor_shape)) {
std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl;
return -1;
}
// Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
// shape, and modes (m', n') correspond to the number of tiles.
//
// These will be used to determine the kernel grid dimensions.
Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape); // ((M, N), m', n')
Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n')
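  // For the (256, 512) tensor and (128, 64) block used here, the tiled shape is
  // ((128, 64), 2, 8), i.e. a 2 x 8 arrangement of tiles.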
// Thread arrangement
Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{}));
// Vector dimensions
Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));
//
// Determine grid and block dimensions
//
dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n'
dim3 blockDim(size(thr_layout));
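  // With the shapes above this yields gridDim = (2, 8) and blockDim = 256 threads.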
//
// Launch the kernel
//
copy_kernel_vectorized<<< gridDim, blockDim >>>(
tiled_tensor_S,
tiled_tensor_D,
thr_layout,
vec_layout);
hipError_t result = hipDeviceSynchronize();
if (result != hipSuccess) {
std::cerr << "HIP Runtime error: " << hipGetErrorString(result) << std::endl;
return -1;
}
//
// Verify
//
h_D = d_D;
int32_t errors = 0;
int32_t const kErrorLimit = 10;
for (size_t i = 0; i < h_D.size(); ++i) {
if (h_S[i] != h_D[i]) {
std::cerr << "Error. S[" << i << "]: " << h_S[i] << ", D[" << i << "]: " << h_D[i] << std::endl;
if (++errors >= kErrorLimit) {
std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl;
return -1;
}
}
}
std::cout << "Success." << std::endl;
return 0;
}
# Run manually to reformat a file:
# clang-format -i --style=file <file>
Language: Cpp
BasedOnStyle: Google
name: Bug Report
description: Let us know that something does not work as expected.
title: "[Bug]: Please title this bug report"
body:
- type: textarea
id: what-happened
attributes:
label: Describe the issue
description: What happened, and what did you expect to happen?
validations:
required: true
- type: textarea
id: steps
attributes:
label: Steps to reproduce the problem
description: It is important that we are able to reproduce the problem that you are experiencing. Please provide all code and relevant steps to reproduce the problem, including your `BUILD`/`CMakeLists.txt` file and build commands. Links to a GitHub branch or [godbolt.org](https://godbolt.org/) that demonstrate the problem are also helpful.
validations:
required: true
- type: textarea
id: version
attributes:
label: What version of GoogleTest are you using?
description: Please include the output of `git rev-parse HEAD` or the GoogleTest release version number that you are using.
validations:
required: true
- type: textarea
id: os
attributes:
label: What operating system and version are you using?
description: If you are using a Linux distribution please include the name and version of the distribution as well.
validations:
required: true
- type: textarea
id: compiler
attributes:
label: What compiler and version are you using?
description: Please include the output of `gcc -v` or `clang -v`, or the equivalent for your compiler.
validations:
required: true
- type: textarea
id: buildsystem
attributes:
label: What build system are you using?
description: Please include the output of `bazel --version` or `cmake --version`, or the equivalent for your build system.
validations:
required: true
- type: textarea
id: additional
attributes:
label: Additional context
description: Add any other context about the problem here.
validations:
required: false
name: Feature request
description: Propose a new feature.
title: "[FR]: Please title this feature request"
labels: "enhancement"
body:
- type: textarea
id: version
attributes:
label: Does the feature exist in the most recent commit?
description: We recommend using the latest commit from GitHub in your projects.
validations:
required: true
- type: textarea
id: why
attributes:
label: Why do we need this feature?
description: Ideally, explain why a combination of existing features cannot be used instead.
validations:
required: true
- type: textarea
id: proposal
attributes:
label: Describe the proposal.
description: Include a detailed description of the feature, with usage examples.
validations:
required: true
- type: textarea
id: platform
attributes:
label: Is the feature specific to an operating system, compiler, or build system version?
description: If it is, please specify which versions.
validations:
required: true
blank_issues_enabled: false
contact_links:
- name: Get Help
url: https://github.com/google/googletest/discussions
about: Please ask and answer questions here.
# Ignore CI build directory
build/
xcuserdata
cmake-build-debug/
.idea/
bazel-bin
bazel-genfiles
bazel-googletest
bazel-out
bazel-testlogs
MODULE.bazel.lock
# python
*.pyc
# Visual Studio files
.vs
*.sdf
*.opensdf
*.VC.opendb
*.suo
*.user
_ReSharper.Caches/
Win32-Debug/
Win32-Release/
x64-Debug/
x64-Release/
# VSCode files
.cache/
cmake-variants.yaml
# Ignore autoconf / automake files
Makefile.in
aclocal.m4
configure
build-aux/
autom4te.cache/
googletest/m4/libtool.m4
googletest/m4/ltoptions.m4
googletest/m4/ltsugar.m4
googletest/m4/ltversion.m4
googletest/m4/lt~obsolete.m4
googlemock/m4
# Ignore generated directories.
googlemock/fused-src/
googletest/fused-src/
# macOS files
.DS_Store
googletest/.DS_Store
googletest/xcode/.DS_Store
# Ignore cmake generated directories and files.
CMakeFiles
CTestTestfile.cmake
Makefile
cmake_install.cmake
googlemock/CMakeFiles
googlemock/CTestTestfile.cmake
googlemock/Makefile
googlemock/cmake_install.cmake
googlemock/gtest
/bin
/googlemock/gmock.dir
/googlemock/gmock_main.dir
/googlemock/RUN_TESTS.vcxproj.filters
/googlemock/RUN_TESTS.vcxproj
/googlemock/INSTALL.vcxproj.filters
/googlemock/INSTALL.vcxproj
/googlemock/gmock_main.vcxproj.filters
/googlemock/gmock_main.vcxproj
/googlemock/gmock.vcxproj.filters
/googlemock/gmock.vcxproj
/googlemock/gmock.sln
/googlemock/ALL_BUILD.vcxproj.filters
/googlemock/ALL_BUILD.vcxproj
/lib
/Win32
/ZERO_CHECK.vcxproj.filters
/ZERO_CHECK.vcxproj
/RUN_TESTS.vcxproj.filters
/RUN_TESTS.vcxproj
/INSTALL.vcxproj.filters
/INSTALL.vcxproj
/googletest-distribution.sln
/CMakeCache.txt
/ALL_BUILD.vcxproj.filters
/ALL_BUILD.vcxproj
# Copyright 2017 Google Inc.
# All Rights Reserved.
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Bazel Build for Google C++ Testing Framework (Google Test)
package(default_visibility = ["//visibility:public"])
licenses(["notice"])
exports_files(["LICENSE"])
config_setting(
name = "qnx",
constraint_values = ["@platforms//os:qnx"],
)
config_setting(
name = "windows",
constraint_values = ["@platforms//os:windows"],
)
config_setting(
name = "freebsd",
constraint_values = ["@platforms//os:freebsd"],
)
config_setting(
name = "openbsd",
constraint_values = ["@platforms//os:openbsd"],
)
# NOTE: Fuchsia is not an officially supported platform.
config_setting(
name = "fuchsia",
constraint_values = ["@platforms//os:fuchsia"],
)
config_setting(
name = "msvc_compiler",
flag_values = {
"@bazel_tools//tools/cpp:compiler": "msvc-cl",
},
visibility = [":__subpackages__"],
)
config_setting(
name = "has_absl",
values = {"define": "absl=1"},
)
# Library that defines the FRIEND_TEST macro.
cc_library(
name = "gtest_prod",
hdrs = ["googletest/include/gtest/gtest_prod.h"],
includes = ["googletest/include"],
)
# Google Test including Google Mock
# For an actual test, use `gtest` and also `gtest_main` if you depend on gtest's
# main(). For a library, use `gtest_for_library` instead if the library can be
# testonly.
cc_library(
name = "gtest",
srcs = glob(
include = [
"googletest/src/*.cc",
"googletest/src/*.h",
"googletest/include/gtest/**/*.h",
"googlemock/src/*.cc",
"googlemock/include/gmock/**/*.h",
],
exclude = [
"googletest/src/gtest-all.cc",
"googletest/src/gtest_main.cc",
"googlemock/src/gmock-all.cc",
"googlemock/src/gmock_main.cc",
],
),
hdrs = glob([
"googletest/include/gtest/*.h",
"googlemock/include/gmock/*.h",
]),
copts = select({
":qnx": [],
":windows": [],
"//conditions:default": ["-pthread"],
}),
defines = select({
":has_absl": ["GTEST_HAS_ABSL=1"],
"//conditions:default": [],
}),
features = select({
":windows": ["windows_export_all_symbols"],
"//conditions:default": [],
}),
includes = [
"googlemock",
"googlemock/include",
"googletest",
"googletest/include",
],
linkopts = select({
":qnx": ["-lregex"],
":windows": [],
":freebsd": [
"-lm",
"-pthread",
],
":openbsd": [
"-lm",
"-pthread",
],
"//conditions:default": ["-pthread"],
}),
deps = select({
":has_absl": [
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/debugging:failure_signal_handler",
"@abseil-cpp//absl/debugging:stacktrace",
"@abseil-cpp//absl/debugging:symbolize",
"@abseil-cpp//absl/flags:flag",
"@abseil-cpp//absl/flags:parse",
"@abseil-cpp//absl/flags:reflection",
"@abseil-cpp//absl/flags:usage",
"@abseil-cpp//absl/strings",
"@re2",
],
"//conditions:default": [],
}) + select({
# `gtest-death-test.cc` has `EXPECT_DEATH` that spawns a process,
# expects it to crash and inspects its logs with the given matcher,
# so that's why these libraries are needed.
# Otherwise, builds targeting Fuchsia would fail to compile.
":fuchsia": [
"@fuchsia_sdk//pkg/fdio",
"@fuchsia_sdk//pkg/zx",
],
"//conditions:default": [],
}),
)
# `gtest`, but testonly. See guidance on `gtest` for when to use this.
alias(
name = "gtest_for_library",
testonly = True,
actual = ":gtest",
)
# Implements main() for tests using gtest. Prefer to depend on `gtest` as well
# to ensure compliance with the layering_check Bazel feature where only the
# direct hdrs values are available.
cc_library(
name = "gtest_main",
srcs = ["googlemock/src/gmock_main.cc"],
features = select({
":windows": ["windows_export_all_symbols"],
"//conditions:default": [],
}),
deps = [":gtest"],
)
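# Example (illustrative only; `my_unittest` is a placeholder, and the sample targets
# below show the in-tree usage): a typical test target depends on both `gtest` and
# `gtest_main`, e.g.
#   cc_test(
#       name = "my_unittest",
#       srcs = ["my_unittest.cc"],
#       deps = [":gtest", ":gtest_main"],
#   )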
# The following rules build samples of how to use gTest.
cc_library(
name = "gtest_sample_lib",
srcs = [
"googletest/samples/sample1.cc",
"googletest/samples/sample2.cc",
"googletest/samples/sample4.cc",
],
hdrs = [
"googletest/samples/prime_tables.h",
"googletest/samples/sample1.h",
"googletest/samples/sample2.h",
"googletest/samples/sample3-inl.h",
"googletest/samples/sample4.h",
],
features = select({
":windows": ["windows_export_all_symbols"],
"//conditions:default": [],
}),
)
cc_test(
name = "gtest_samples",
size = "small",
# All Samples except:
# sample9 (main)
# sample10 (main and takes a command line option and needs to be separate)
srcs = [
"googletest/samples/sample1_unittest.cc",
"googletest/samples/sample2_unittest.cc",
"googletest/samples/sample3_unittest.cc",
"googletest/samples/sample4_unittest.cc",
"googletest/samples/sample5_unittest.cc",
"googletest/samples/sample6_unittest.cc",
"googletest/samples/sample7_unittest.cc",
"googletest/samples/sample8_unittest.cc",
],
linkstatic = 0,
deps = [
"gtest_sample_lib",
":gtest_main",
],
)
cc_test(
name = "sample9_unittest",
size = "small",
srcs = ["googletest/samples/sample9_unittest.cc"],
deps = [":gtest"],
)
cc_test(
name = "sample10_unittest",
size = "small",
srcs = ["googletest/samples/sample10_unittest.cc"],
deps = [":gtest"],
)
# Note: CMake support is community-based. The maintainers do not use CMake
# internally.
cmake_minimum_required(VERSION 3.16)
project(googletest-distribution)
set(GOOGLETEST_VERSION 1.16.0)
if(NOT CYGWIN AND NOT MSYS AND NOT ${CMAKE_SYSTEM_NAME} STREQUAL QNX)
set(CMAKE_CXX_EXTENSIONS OFF)
endif()
enable_testing()
include(CMakeDependentOption)
include(GNUInstallDirs)
# Note that googlemock target already builds googletest.
option(BUILD_GMOCK "Builds the googlemock subproject" ON)
option(INSTALL_GTEST "Enable installation of googletest. (Projects embedding googletest may want to turn this OFF.)" ON)
option(GTEST_HAS_ABSL "Use Abseil and RE2. Requires Abseil and RE2 to be separately added to the build." OFF)
if(GTEST_HAS_ABSL)
if(NOT TARGET absl::base)
find_package(absl REQUIRED)
endif()
if(NOT TARGET re2::re2)
find_package(re2 REQUIRED)
endif()
endif()
if(BUILD_GMOCK)
add_subdirectory( googlemock )
else()
add_subdirectory( googletest )
endif()
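# Example (illustrative, not part of this file): a project embedding googletest can
# consume it with something like
#   add_subdirectory(googletest EXCLUDE_FROM_ALL)
#   target_link_libraries(my_tests PRIVATE GTest::gtest_main)
# typically together with -DINSTALL_GTEST=OFF so googletest is not installed
# alongside the embedding project.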
Copyright 2008, Google Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.