Commit d22dbec2 authored by zhoux

Initial commit: release hytlass-0.1.0
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to run group convolution kernels on tensor cores using functions and data
structures provided by HYTLASS.
There are two group conv modes:
1. hytlass::conv::GroupMode::kSingleGroup
This mode is for large K problem sizes: k_per_group (K/groups) is equal to or larger than
threadblock_tile_N. One or more threadblocks compute the data of a single group.
2. hytlass::conv::GroupMode::kMultipleGroup
This mode is for small K problem sizes: k_per_group (K/groups) is smaller than threadblock_tile_N.
One threadblock computes data from more than one group.
The function profile_convolution_selecter() shows how to choose the kernel with the appropriate
group mode according to the problem size and the threadblock tile size. A worked example of the
selection rule appears right after this comment.
*/
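// A worked illustration of the selection rule described above (not part of the kernel
// definitions), assuming ThreadblockShape::kN == 128 as configured in this example:
//   K = 128, groups = 8  ->  k_per_group = 16  <  128  ->  kMultipleGroup is required
//   K = 256, groups = 2  ->  k_per_group = 128 >= 128  ->  kSingleGroup can be used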
#include <iostream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/conv/kernel/default_conv2d_group_fprop.h"
#include "hytlass/conv/device/implicit_gemm_convolution.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/convolution.h"
#include "hytlass/util/reference/device/convolution.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input and output tensors and of the
// computation between elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = hytlass::half_t; // Data type of elements in input tensor
using ElementInputB = hytlass::half_t; // Data type of elements in input tensor
using ElementOutput = float; // Data type of elements in output tensor
using LayoutInputA = hytlass::layout::TensorNHWC;
using LayoutInputB = hytlass::layout::TensorNHWC;
using LayoutOutput = hytlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes hip Gfx architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Number of pipelines you want to use
constexpr int NumStages = 1;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = hytlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / hytlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
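// For reference: with the parameters above, the LinearCombination epilogue computes, per output
// element, D = alpha * accumulator + beta * C. Since ElementOutput is float (32-bit), the vector
// width works out to 128 / 32 = 4 elements per memory access.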
// Analytic kernel and operation for single group problem size
using AnalyticSingleGroupKernel = typename hytlass::conv::kernel::DefaultConv2dGroupFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
hytlass::conv::GroupMode::kSingleGroup,
hytlass::conv::IteratorAlgorithm::kAnalytic>::Kernel;
using AnalyticSingleGroupOperation = hytlass::conv::device::ImplicitGemmConvolution<AnalyticSingleGroupKernel>;
// Analytic kernel and operation for multiple group problem size
using AnalyticMultipleGroupKernel = typename hytlass::conv::kernel::DefaultConv2dGroupFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
hytlass::conv::GroupMode::kMultipleGroup,
hytlass::conv::IteratorAlgorithm::kAnalytic>::Kernel;
using AnalyticMultipleGroupOperation = hytlass::conv::device::ImplicitGemmConvolution<AnalyticMultipleGroupKernel>;
// Optimized kernel and operation for single group problem size
using OptimizedSingleGroupKernel = typename hytlass::conv::kernel::DefaultConv2dGroupFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
hytlass::conv::GroupMode::kSingleGroup,
hytlass::conv::IteratorAlgorithm::kOptimized
>::Kernel;
using OptimizedSingleGroupOperation = hytlass::conv::device::ImplicitGemmConvolution<OptimizedSingleGroupKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::Tensor4DCoord input_size;
hytlass::Tensor4DCoord filter_size;
hytlass::Tensor4DCoord padding;
hytlass::MatrixCoord conv_stride;
hytlass::MatrixCoord dilation;
int groups;
bool reference_check;
bool measure_performance;
int iterations;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool optimized;
std::string tag;
Options():
help(false),
input_size(1, 32, 32, 32),
filter_size(32, 3, 3, 32),
padding(1, 1, 1, 1),
conv_stride(1, 1),
dilation(1, 1),
groups(1),
reference_check(true),
measure_performance(true),
iterations(20),
alpha(1),
beta(0),
optimized(false) { }
// Verify the problem size is compatible with the HYTLASS Convolution implementation.
bool valid() {
//
// HYTLASS attempts to load 128b vectors of hytlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 128 / hytlass::sizeof_bits<ElementInputA>::value;
if ((input_size.c() % kAlignment) ||
(filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) ||
(padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
hytlass::Tensor4DCoord input_size,
hytlass::Tensor4DCoord filter_size) {
this->input_size = input_size;
this->filter_size = filter_size;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("optimized")) {
optimized = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
cmd.get_cmd_line_argument("g", groups);
filter_size.c() = input_size.c() / groups;
cmd.get_cmd_line_argument("u", conv_stride.row());
cmd.get_cmd_line_argument("v", conv_stride.column());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
if (filter_size.h() == 3 && filter_size.w() == 3) {
padding = {1, 1, 1, 1};
}
else {
filter_size.h() = 1;
filter_size.w() = 1;
padding = {0, 0, 0, 0};
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "11_hytlass_tensorop_group_conv example\n\n"
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
<< " forward grouped convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --g=<int> Conv groups G\n\n"
<< " --u=<int> Conv stride_h\n\n"
<< " --v=<int> Conv stride_w\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --ref-check If set (true), reference check is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --optimized If set (true), use optimized kernel, otherwise use analytic kernel.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/11_hytlass_tensorop_group_conv/gfx928_tensorop_group_conv2dfprop --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=8 --ref-check\n\n"
<< "$ ./examples/11_hytlass_tensorop_group_conv/gfx928_tensorop_group_conv2dfprop --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check\n\n"
<< "$ ./examples/11_hytlass_tensorop_group_conv/gfx928_tensorop_group_conv2dfprop --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check --optimized\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
hytlass::Tensor4DCoord output_size() const {
return hytlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
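// Worked example with the default Options above: input 1x32x32x32 (NHWC), filter 32x3x3x32,
// padding 1, stride 1 gives P = Q = (32 + 1 + 1 - 3) / 1 + 1 = 32, so the output is 1x32x32x32 (NPQK).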
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * (C/groups) * R * S (filter_size.c() equals C / groups)
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hytlass::Status reference_check;
hipError_t error;
Result():
runtime_ms(0),
gflops(0),
status(hytlass::Status::kSuccess),
reference_check(hytlass::Status::kInvalid),
error(hipSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,G,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< options.groups << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
template <typename Conv2dOperation>
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the HYTLASS Utilities.
//
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(7),
ElementInputA(-8),
0);
// Fill tensor B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(7),
ElementInputB(-8),
0);
// Fill tensor C on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0);
// Fill tensor D on host with zeros
hytlass::reference::host::TensorFill(
tensor_d.host_view());
// Fill tensor D for reference on host with zeros
hytlass::reference::host::TensorFill(
tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
//
// Define arguments for HYTLASS Convolution
//
hytlass::conv::Mode mode = hytlass::conv::Mode::kCrossCorrelation;
// Split K dimension into 1 partition (no split-K)
int split_k_slices = 1;
// Construct Conv2dProblemSize with user defined output size
hytlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices,
options.groups
);
// Construct Conv2dOperation::Arguments structure with conv2d
// problem size, data pointers, and epilogue values
typename Conv2dOperation::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_d.device_ref(),
{options.alpha, options.beta},
};
//
// Initialize HYTLASS Convolution
//
Conv2dOperation implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_op.can_implement(arguments);
HYTLASS_CHECK(result.status);
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(result.status);
//
// Launch initialized HYTLASS kernel
//
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on device...\n";
// Compute with reference implementation
hytlass::reference::device::Conv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
hytlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_ref_d.device_ref(),
options.alpha,
options.beta
);
tensor_ref_d.sync_host();
// Check if output from HYTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();
bool passed = hytlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
if (!passed) {
result.reference_check = hytlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
} else {
result.reference_check = hytlass::Status::kSuccess;
std::cout << "Passed.\n";
}
} else {
result.reference_check = hytlass::Status::kInvalid;
}
//
// Performance measurement
//
if (options.measure_performance) {
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
Result profile_convolution_selecter(Options const &options) {
int k_per_group = options.filter_size.n() / options.groups;
// In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups
if (k_per_group < ThreadblockShape::kN) { // MultipleGroup mode
if (options.optimized) {
std::cerr << "Invalid problem: optimized group conv kernel doesn't support MultipleGroup (one CTA calculate multiple groups) mode" << std::endl;
exit(-1);
} else {
std::cout << "Select AnalyticMultipleGroupOperation\n";
return profile_convolution<AnalyticMultipleGroupOperation>(options);
}
} else { // SingleGroup mode
if (options.optimized) {
std::cout << "Select OptimizedSingleGroupOperation\n";
return profile_convolution<OptimizedSingleGroupOperation>(options);
} else {
std::cout << "Select AnalyticSingleGroupOperation\n";
return profile_convolution<AnalyticSingleGroupOperation>(options);
}
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution_selecter(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
depthwise_simt_conv2dfprop
depthwise_simt_conv2dfprop.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to run depthwise 2d convolution kernels using functions and data structures
provided by HYTLASS, using SIMT instructions.
There are three implementations of depthwise 2d convolution:
1. kAnalytic
The implicit GEMM 2d convolution algorithm.
2. kOptimized
An optimized algorithm that supports arbitrary stride and dilation.
3. kFixedStrideDilation
An optimized algorithm with fixed stride and dilation, which reduces runtime computation and enables
further optimizations.
In general, kFixedStrideDilation performs better than kOptimized. However, if the filter size, stride,
or dilation is large, it may encounter register spilling and hurt performance; in that case, please
use kOptimized.
For kOptimized and kFixedStrideDilation, when the output tensor is large, splitK should be enabled to
fully utilize GPU hardware resources and achieve better performance.
This example demonstrates how to construct and run a kFixedStrideDilation depthwise 2d convolution
kernel.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/conv/kernel/default_depthwise_fprop.h"
#include "hytlass/conv/device/implicit_gemm_convolution.h"
#include "hytlass/conv/device/direct_convolution.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/convolution.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input and output tensors and of the
// computation between elements
using ElementAccumulator = hytlass::half_t; // Data type of accumulator
using ElementComputeEpilogue = hytlass::half_t; // Data type of epilogue computation (alpha, beta)
using ElementInputA = hytlass::half_t; // Data type of elements in input tensor
using ElementInputB = hytlass::half_t; // Data type of elements in input tensor
using ElementOutput = hytlass::half_t; // Data type of elements in output tensor
using LayoutInputA = hytlass::layout::TensorNHWC;
using LayoutInputB = hytlass::layout::TensorNHWC;
using LayoutOutput = hytlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassSimt;
// This code section describes hip Gfx architecture number
using SmArch = hytlass::arch::Gfx906;
// This code section describes the groups a thread block will compute
constexpr int groups_per_cta = 64;
// This code section describes the output tile <N, O, P, Q> a thread block will compute
using ThreadBlockOutputShape = hytlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>;
// This code section describes the filter shape <R, S>
using FilterShape = hytlass::MatrixShape<3, 3>;
// Threadblock tile shape
using ThreadblockShape =
hytlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, groups_per_cta, FilterShape::kCount>;
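// With the values chosen above (and ThreadBlockOutputShape::kNHW = 1 * 8 * 8 = 64), this resolves to
// GemmShape<64, 64, 9>: 64 output pixels, 64 groups (channels), and 9 filter taps per threadblock.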
// This code section describes the tile size a warp will compute
// WarpShape::kM = the P * Q output elements the warp processes
// WarpShape::kN = the groups_per_cta the warp processes
// WarpShape::kK = the filter size the warp processes
using WarpShape = hytlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>;
// This code section describes the size of MMA op
using InstructionShape = hytlass::gemm::GemmShape<1, 1, 1>;
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock =
hytlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
1,
ThreadBlockOutputShape::kN,
ThreadBlockOutputShape::kH,
ThreadBlockOutputShape::kW>;
// Number of pipelines you want to use
constexpr int NumStages = 2;
// This code section describes the iterator algorithm selected: kFixedStrideDilation
static hytlass::conv::IteratorAlgorithm const IteratorAlgorithm =
hytlass::conv::IteratorAlgorithm::kFixedStrideDilation;
using StrideShape = hytlass::MatrixShape<1, 1>;
using DilationShape = hytlass::MatrixShape<1, 1>;
constexpr int kEpilogueElementsPerAccess = 128 / hytlass::sizeof_bits<ElementOutput>::value;
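// For hytlass::half_t (16-bit) outputs this resolves to 128 / 16 = 8 elements per epilogue access.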
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = hytlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
kEpilogueElementsPerAccess, // The number of elements per vectorized.
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue, // Data type for alpha/beta in linear combination
hytlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; // Epilogue scaling operation.
using DepthwiseDirect2dConv = typename hytlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
hytlass::conv::StrideSupport::kFixed,
StrideShape,
DilationShape>::Kernel;
using Direct2dConv = hytlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::Tensor4DCoord input_size;
hytlass::Tensor4DCoord filter_size;
hytlass::Tensor4DCoord padding;
hytlass::MatrixCoord conv_stride;
hytlass::MatrixCoord dilation;
int groups;
int splitk;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
std::string tag;
Options()
: help(false),
input_size(1, 128, 128, 32),
filter_size(32, 3, 3, 1),
padding(1, 1, 1, 1),
conv_stride(1, 1),
dilation(1, 1),
groups(32),
splitk(1),
reference_check(true),
measure_performance(true),
iterations(20),
save_workspace(false),
alpha(1),
beta(0) {}
// Verify the problem size is compatible with the HYTLASS Convolution implementation.
bool valid() {
//
// HYTLASS attempts to load 128b vectors of hytlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// depthwise conv
if (groups != input_size.c()) {
return false;
}
if (filter_size.n() != groups) {
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) || (padding.w() != filter_size.w() / 2)) {
return false;
}
// Filter size passed through command line does not match filter size template parameter
if (filter_size.h() != FilterShape::kRow || filter_size.w() != FilterShape::kColumn) {
std::cerr << "Filter size passed in (" << filter_size.h() << "x" << filter_size.w() << ") "
<< "must match the FilterShape template parameter of the convolution "
<< "(" << FilterShape::kRow << "x" << FilterShape::kColumn << "). "
<< "To use the filter shape passed in, change the FilterShape template "
<< "parameter and recompile this example."
<< std::endl;
return false;
}
return true;
}
/// Updates input and filter sizes
void update(hytlass::Tensor4DCoord input_size, hytlass::Tensor4DCoord filter_size) {
this->input_size = input_size;
this->filter_size = filter_size;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
cmd.get_cmd_line_argument("g", groups);
filter_size.c() = 1;
filter_size.n() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("splitk", splitk);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
int32_t padding_h = filter_size.h() / 2;
int32_t padding_w = filter_size.w() / 2;
padding = {padding_h, padding_h, padding_w, padding_w};
}
/// Prints the usage statement.
std::ostream &print_usage(std::ostream &out) const {
out << "12_depthwise_simt_conv2dfprop example\n\n"
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --g=<int> Groups\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --splitk=<int> Enable splitK\n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag=<string> String to replicate across the first column in the results "
"table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/12_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop --n=32 "
"--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n"
<< "$ ./examples/12_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop --n=1 "
"--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
hytlass::Tensor4DCoord output_size() const {
return hytlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * R * S * filter_size.c() (filter_size.c() == 1 for depthwise)
int64_t fmas =
output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hytlass::Status reference_check;
hipError_t error;
Result()
: runtime_ms(0),
gflops(0),
status(hytlass::Status::kSuccess),
reference_check(hytlass::Status::kInvalid),
error(hipSuccess) {}
static std::ostream &print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,G,stride_h,stride_w,dilation_h,dilation_w,splitK,Runtime,GFLOPs";
return out;
}
std::ostream &print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
hytlass::Tensor4DCoord output_size = options.output_size();
out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << ","
<< options.input_size.w() << "," << options.input_size.c() << ","
<< options.filter_size.n() << "," << options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< options.groups << "," << options.conv_stride.row() << "," << options.conv_stride.column()
<< ","
<< options.dilation.row() << "," << options.dilation.column() << ","
<< options.splitk << ","
<< runtime_ms << "," << gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one testcase
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the HYTLASS Utilities.
//
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b_transpose(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(), 1, ElementInputA(5), ElementInputA(-6), 0);
// Fill tensor B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(), 1, ElementInputB(3), ElementInputB(-6), 0);
// Fill tensor C on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(), 1, ElementOutput(5), ElementOutput(-6), 0);
// Fill tensor D on host with zeros
hytlass::reference::host::TensorFill(tensor_d.host_view());
// Fill tensor D for reference on host with zeros
hytlass::reference::host::TensorFill(tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_b_transpose.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
//
// Define arguments for HYTLASS Convolution
//
hytlass::conv::Mode mode = hytlass::conv::Mode::kCrossCorrelation;
// Split P*Q across multiple CTAs
int split_k_slices = options.splitk;
// Construct Conv2dProblemSize with user defined output size
hytlass::conv::Conv2dProblemSize problem_size(options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices,
options.groups);
// Construct Direct2dConv::Arguments structure with conv2d
// problem size, data pointers, and epilogue values
typename Direct2dConv::Arguments arguments{problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_d.device_ref(),
{options.alpha, options.beta},
tensor_b_transpose.device_ref()};
//
// Initialize HYTLASS Convolution
//
Direct2dConv implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_op.can_implement(arguments);
HYTLASS_CHECK(result.status);
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(result.status);
//
// Launch initialized HYTLASS kernel
//
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on host...\n";
// Compute with reference implementation
hytlass::reference::host::Conv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator >(problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_d.host_ref(),
options.alpha,
options.beta);
// Check if output from HYTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();
bool passed =
hytlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view());
if (!passed) {
result.reference_check = hytlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
} else {
result.reference_check = hytlass::Status::kSuccess;
std::cout << "Passed.\n";
}
} else {
result.reference_check = hytlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "46_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
<< "x" << options.input_size.w() << "x" << options.input_size.c() << "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x"
<< options.filter_size.w() << "x" << options.filter_size.c() << ".dat";
std::ofstream output_workspace(ss.str());
output_workspace << "Input = \n"
<< tensor_a.host_view() << "\n\n"
<< "Filters = \n"
<< tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
hipEvent_t events[2];
for (auto &event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error)
<< std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gfx928_tensorop_conv2d_bias_relu
gfx928_tensorop_conv2d_bias_relu.cu
)
hytlass_example_add_executable(
gfx928_tensorop_conv2d_bias_add_relu
gfx928_tensorop_conv2d_bias_add_relu.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to run convolution kernels using functions and data structures
provided by HYTLASS using tensor cores.
Writing a single high-performance convolution kernel is hard but doable, whereas writing
high-performance kernels at scale that work for multiple problem sizes with good abstractions is
really hard. HYTLASS solves this problem by providing simplified abstractions to compose the multiple
sections of an implicit GEMM kernel. When used properly, the kernels can easily approach peak GPU
performance.
HYTLASS divides a kernel into hierarchical, composable sections. At each level (thread, warp, and
thread-block) the kernel computes on its own tile size, with higher-level tile sizes composed from
lower-level ones. Multiple thread-tiles (the tile size each thread computes) form a warp-tile (the
tile size each warp computes), and multiple warp-tiles form a threadblock-tile (the tile size computed
by a threadblock).
In this example, we split variable initialization into
1. Setting up data properties: describes how tensors are laid out in memory and how the kernel
can view them (logical to physical mapping)
2. Setting up computation properties: describes how the above tensors will be used to compute the
output of the convolution.
First, we set up the data types of the input tensor A, the weights tensor B, and the output tensor C,
along with alpha and beta, as the equation for convolution is C = alpha * Conv(A, B) + beta * C. In
HYTLASS, the kernels first compute Conv(A, B) and leave the rest of the computation to the end of the
kernel, because alpha * X + beta * C is a simple element-wise operation on X (= Conv(A, B)) and C. We
call this the epilogue of the kernel. Hence, we set the data types for alpha and beta to
ElementComputeEpilogue = float. This example uses Tensor Core MMA instructions on F16 inputs, so the
elements of input tensors A and B use the data type hytlass::half_t. We convey this to the HYTLASS
kernel by initializing the template variables ElementAccumulator (float), ElementComputeEpilogue
(float), ElementInputA (hytlass::half_t), ElementInputB (hytlass::half_t), and ElementOutput (float).
Communicating just the data types is not enough. As the data is laid out linearly in memory, we also
have to convey the layout of the tensors. We do that by initializing the template variables
LayoutInputA, LayoutInputB, and LayoutOutput to the hytlass TensorNHWC layout. Next, we set up the
rules to compute alpha * X + beta * C, which is called the epilogue of the kernel. We initialize the
template variable EpilogueOp, which takes the data type of the output ElementOutput (float), the
accumulator data type (float), the data type used for the linear-combination computation (float), the
number of elements per vector memory access (128 / 32 = 4), and the element-wise operations fused into
this example's residual-block epilogue.
Now that we have set up the properties of the data, we have to set up the properties of the
computation.
Second, we create template variables for the tile sizes of the thread-block, warp, and mma-op:
64x64x128, 32x32x128, and 16x16x32 (MxNxK) respectively. When these are passed to instantiate the
HYTLASS implicit GEMM kernel, it internally deduces the number of threads needed per thread-block,
the amount of shared memory, how to store data in a bank-conflict-free manner, and a ton of other
variables required to compose, initialize, and launch a high-performance implicit GEMM kernel. This is
the beauty of HYTLASS: it relieves the developer from understanding and coding complicated hardware
optimizations, which can easily go wrong.
HYTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? An MMA pipeline
constitutes the whole process of loading input data from global memory to shared memory, loading data
from shared memory to registers, doing the matrix multiplication, and storing to global memory. The
flow below shows a typical MMA pipeline.
tensor in global memory -> registers -> tile in shared memory -> registers -> mma -> registers ->
output to global memory
The problem with a single pipeline is that each stage is synchronous, so each stage has to wait until
the previous one has finished executing. Some stages in the pipeline do not have fixed latency, for
example the loads from global memory and shared memory. Therefore, we can add one more pipeline with a
phase shift in the MMA kernel to hide the latency of the global and shared memory loads. With that,
the pipelines in a kernel look like
(1) tensor in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers -> (5) mma -> (6) registers -> (7) output to global memory
(1) <null> -> (2) <null> -> (3) tensor in global memory -> (4) registers -> (5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers -> (9) output to global memory
This way, you can hide the latency of the second global memory load by doing computation on
already-loaded input data.
A few more template variables are initialized, such as the swizzle that decides which threadblock tile
of the output matrix is computed by which threadblock launched on an SM, and the GFX architecture of
the GPU you want to run on.
These are all put together to create a template variable that describes the HYTLASS implicit GEMM
kernel using the hytlass::conv::device::ImplicitGemmConvolution template.
The next step is to initialize physical data, instantiate and initialize the HYTLASS kernel, and run
it. We use HYTLASS utilities to initialize, fill, and compare tensors, as they are simple and do not
come in the way of learning HYTLASS.
Once all the tensors are initialized and filled with data, we create the arguments tuple used to
launch the HYTLASS kernel, which takes the problem size (by default N = 1, H = 14, W = 14, C = 1024),
the filter size (K = 256, R = 1, S = 1, C = 1024), padding, strides, dilation, tensors, alpha, beta,
and the important one, the split k-dimension factor. Along with that, we query HYTLASS whether any
scratch-space memory is required by the kernel we instantiated. If so, we allocate it and pass it
along with the other arguments to initialize the HYTLASS kernel; then the kernel is launched.
In this example, we later launch a reference convolution kernel (from the HYTLASS utilities) to check
whether the output of the HYTLASS kernel matches the reference implicit GEMM kernel.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
#include "hytlass/conv/device/implicit_gemm_convolution.h"
#include "hytlass/epilogue/thread/linear_combination_residual_block.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/device/convolution.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input and output tensors and of the
// computation between elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = hytlass::half_t; // Data type of elements in input tensor
using ElementInputB = hytlass::half_t; // Data type of elements in input tensor
using ElementOutput = float; // Data type of elements in output tensor
using LayoutInputA = hytlass::layout::TensorNHWC;
using LayoutInputB = hytlass::layout::TensorNHWC;
using LayoutOutput = hytlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes GFX architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = hytlass::gemm::GemmShape<64, 64, 128>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = hytlass::gemm::GemmShape<32, 32, 128>; // Warp tile shape
// This code section describes the size of the MMA op
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 32>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
// 1 -> singlestage
// 2 -> pipelined
constexpr int NumStages = 1;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = hytlass::epilogue::thread::LinearCombinationResidualBlock<
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ElementOutput,
128 / hytlass::sizeof_bits<ElementOutput>::value,
hytlass::epilogue::thread::Identity,
hytlass::plus,
hytlass::epilogue::thread::ReLu>;
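// Informal reading of the epilogue composed above (Identity activation, plus, ReLu): the kernel
// fuses a residual/bias addition with a ReLU, roughly D = ReLU(alpha * Conv(A, B) + beta * C + residual).
// This is a sketch of the intent, not the exact template contract; see
// linear_combination_residual_block.h for the precise definition.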
using Conv2dFpropKernel = typename hytlass::conv::kernel::DefaultConv2dFpropWithBroadcast<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
hytlass::conv::IteratorAlgorithm::kOptimized,
hytlass::conv::StrideSupport::kStrided,
8,
8>::Kernel;
using ImplicitGemm = hytlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::Tensor4DCoord input_size;
hytlass::Tensor4DCoord filter_size;
hytlass::Tensor4DCoord padding;
hytlass::MatrixCoord conv_stride;
hytlass::MatrixCoord dilation;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool benchmark;
std::string tag;
Options():
help(false),
input_size(1, 14, 14, 1024),
filter_size(256, 1, 1, 1024),
padding(0, 0, 0, 0),
conv_stride(1, 1),
dilation(1, 1),
reference_check(true),
measure_performance(true),
iterations(20),
save_workspace(false),
alpha(1),
beta(1),
benchmark(false) { }
// Verify the problem size is compatible with the HYTLASS Convolution implementation.
bool valid() {
//
// HYTLASS attempts to load 128b vectors. Consequently,
    // all pointers, strides, and tensor extents must be divisible by 8 half_t elements.
//
// int const kAlignment = 128 / hytlass::sizeof_bits<ElementInputA>::value;
// if ((input_size.c() % kAlignment) ||
// (filter_size.n() % kAlignment)) {
// // misaligned tensors
// return false;
// }
// Invalid padding
if ((padding.h() != filter_size.h() / 2) ||
(padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
hytlass::Tensor4DCoord input_size,
hytlass::Tensor4DCoord filter_size) {
this->input_size = input_size;
this->filter_size = filter_size;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
if (cmd.check_cmd_line_flag("benchmark")) {
      benchmark = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.c() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
int32_t padding_h = filter_size.h() / 2;
int32_t padding_w = filter_size.w() / 2;
padding = {padding_h, padding_h, padding_w, padding_w};
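    // For example, a 3x3 filter (r = s = 3) gives padding {1, 1, 1, 1}, which with stride 1 and
    // dilation 1 produces a same-sized output (P = H, Q = W); a 1x1 filter gives zero padding.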
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "13_hytlass_tensorop_fused_conv2d_fprop example\n\n"
<< " This example uses Turing's Tensor Core operators on int4 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./gfx928_tensorop_conv2d_bias_add_relu --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
<< "$ ./gfx928_tensorop_conv2d_bias_add_relu --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
hytlass::Tensor4DCoord output_size() const {
return hytlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
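  // Worked example with the default sizes above (N = 1, H = W = 14, C = 1024, K = 256, R = S = 1):
  // the output has 1 * 14 * 14 * 256 = 50,176 elements, each needing C * R * S = 1024 multiply-adds,
  // so fmas = 50,176 * 1024 = 51,380,224 and the kernel performs about 2 * fmas ~= 0.103 GFLOP.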
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hytlass::Status reference_check;
hipError_t error;
Result():
runtime_ms(0),
gflops(0),
status(hytlass::Status::kSuccess),
reference_check(hytlass::Status::kInvalid),
error(hipSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the HYTLASS Utilities.
//
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias({1, 1, 1, options.output_size().c()});
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_res(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
//
// Initialize tensors
//
  // Fill tensor A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(4),
ElementInputA(-4),
4);
// Fill tensor B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
2,
ElementInputB(5),
ElementInputB(-5),
5);
  // Fill the bias tensor C on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_c_bias.host_view(),
1,
ElementOutput(1),
ElementOutput(-1),
0);
// Fill tensor D on host with zeros
hytlass::reference::host::TensorFill(
tensor_d.host_view());
  // Fill the residual tensor on host with uniform-distribution random data
  hytlass::reference::host::TensorFillRandomUniform(
tensor_res.host_view(),
1,
ElementOutput(1),
ElementOutput(-1),
0);
// Fill tensor D for reference on host with zeros
hytlass::reference::host::TensorFill(
tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c_bias.sync_device();
tensor_d.sync_device();
tensor_res.sync_device();
tensor_ref_d.sync_device();
//
// Define arguments for HYTLASS Convolution
//
// mode (kCrossCorrelation or kConvolution)
hytlass::conv::Mode mode = hytlass::conv::Mode::kCrossCorrelation;
// Split K dimension into 1 partitions
int split_k_slices = 1;
float alpha = 1;
float beta = 1;
// Construct Conv2dProblemSize with user defined output size
hytlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices);
// Construct ImplicitGemm::Argument structure with conv2d
// problem size, data pointers, and epilogue values
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_res.device_ref(),
tensor_d.device_ref(),
{alpha,beta},
hytlass::conv::SplitKMode::kNone,
(hytlass::half_t*)(tensor_c_bias.device_data()),
nullptr, 0, options.output_size().c()
};
//
// Initialize HYTLASS Convolution
//
ImplicitGemm implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_op.can_implement(arguments);
HYTLASS_CHECK(result.status);
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(result.status);
//
// Launch initialized HYTLASS kernel
//
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on host...\n";
// Compute with reference implementation
hytlass::reference::device::Conv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
hytlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c_bias.device_ref(),
tensor_ref_d.device_ref(),
1,
0
);
// Check if output from HYTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();
tensor_ref_d.sync_host();
    // Compute bias + residual add + ReLU in host code
for (int n = 0; n < problem_size.N; ++n) {
for (int p = 0; p < problem_size.P; ++p) {
for (int q = 0; q < problem_size.Q; ++q) {
for (int k = 0; k < problem_size.K; ++k) {
tensor_ref_d.at({n, p, q, k}) =
std::max(ElementOutput(0),
ElementOutput(tensor_ref_d.at({n, p, q, k}) +
tensor_c_bias.at({0, 0, 0, k})+tensor_res.at({n, p, q, k})));
}
}
}
}
const ElementOutput non_zero_floor(1e-6f);
ElementOutput eps(1e-3f);
bool passed = hytlass::reference::host::TensorRelativelyEquals(tensor_d.host_view(), tensor_ref_d.host_view(), eps, non_zero_floor);
if (!passed) {
result.reference_check = hytlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
}
else {
result.reference_check = hytlass::Status::kSuccess;
std::cout << "Passed.\n";
}
}
else {
result.reference_check = hytlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "09_tensor_conv_workspace_conv2dfprop_"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
<< ".dat";
std::ofstream output_workspace(ss.str());
output_workspace
<< "Input = \n" << tensor_a.host_view() << "\n\n"
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.benchmark) {
// Benchmark several layers
int batch_sizes[] = {1, 32, 64, 128, 256, 512};
struct Benchmark {
int h, w, c, k, r, s;
} layers[] = {
{56, 56, 64, 256, 1, 1},
{56, 56, 64, 64, 1, 1},
{56, 56, 64, 64, 3, 3},
{56, 56, 256, 64, 1, 1},
{56, 56, 256, 512, 1, 1},
{56, 56, 256, 128, 1, 1},
{28, 28, 128, 128, 3, 3},
{28, 28, 128, 512, 1, 1},
{28, 28, 512, 128, 1, 1},
{28, 28, 512, 1024, 1, 1},
{28, 28, 512, 256, 1, 1},
{14, 14, 256, 256, 3, 3},
{14, 14, 256, 1024, 1, 1},
{14, 14, 1024, 256, 1, 1},
{14, 14, 1024, 2048, 1, 1},
{14, 14, 1024, 512, 1, 1},
{7, 7, 512, 512, 3, 3},
};
Result::print_header(std::cout, options) << std::endl;
int idx = 1;
for (auto const &layer : layers) {
for (auto N : batch_sizes) {
options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});
Result result = profile_convolution(options);
result.print(std::cout, idx, options) << std::endl;
}
++idx;
}
}
else {
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to run convolution kernels using functions and data structures
provided by HYTLASS using tensor cores.
Writing a single high-performance convolution kernel is hard but doable, whereas writing
high-performance kernels at scale that work for multiple problem sizes with good abstractions is
really hard. HYTLASS solves this problem by providing simplified abstractions to compose the
sections of an implicit GEMM kernel. When used properly, the kernels can easily hit peak performance
of the GPU.
HYTLASS divides a kernel into hierarchical, composable sections. This means that at the thread, warp
and thread-block level, each computes on its own tile size, with higher-level tile sizes composed
from lower-level ones. Multiple thread tiles (the tile size each thread computes) form a warp tile
(the tile size each warp computes), and multiple warp tiles form a threadblock tile (the tile size
computed by a threadblock).
In this example, we split variable initialization into
1. Setting up data properties : describes how tensors are laid out in the memory and how the kernel
can view them (logical to physical mapping)
2. Setting up computation properties : describes how the above set tensors will be used to compute
output of convolution.
First, we set up the data types of the input tensor A, the weights tensor B and the output tensor C,
along with alpha and beta, since the convolution equation is C = alpha * Conv(A, B) + beta * C. In
HYTLASS, the kernels first compute Conv(A, B) and defer the rest of the computation to the end of the
kernel, because alpha * X + beta * C is a simple element-wise operation on X (Conv(A, B)) and C. We
call this the epilogue of the kernel. Hence, we set the data type of alpha and beta to
ElementComputeEpilogue = float. In this example we use Tensor Core MMA instructions on half-precision
inputs, so the elements of input tensors A and B use hytlass::half_t. We convey this to the HYTLASS
kernel by initializing the template variables ElementAccumulator (float), ElementComputeEpilogue
(float), ElementInputA (hytlass::half_t), ElementInputB (hytlass::half_t) and ElementOutput (float).
Communicating just the data type is not enough. As the data is laid out linearly in memory, we also
have to convey the layout of the tensors. We do that by initializing the template variables
LayoutInputA, LayoutInputB and LayoutOutput to the hytlass TensorNHWC layout. Next, we set up the
rules to compute alpha * X + beta * C, which is the epilogue of the kernel. We initialize the template
variable EpilogueOp, which takes the data type of the output ElementOutput (float), the number of
elements per vector memory access (4 for 32-bit output elements), the data type of the accumulator
(float) and the data type of the linear-combination computation (alpha * X + beta * C).
Now that we have set up the properties of the data, we have to set up the properties of the computation.
Second, we set the template variables for the tile sizes of the thread-block, warp and mma-op to
64x64x128, 32x32x128 and 16x16x32 (MxNxK) respectively. When these are passed to instantiate the
HYTLASS Implicit GEMM kernel, it internally deduces the number of threads needed per thread-block, the
amount of shared memory, how to store data in a bank-conflict-free manner, and the many other variables
required to compose, initialize and launch a high-performance Implicit GEMM kernel. This is the beauty
of HYTLASS: it relieves the developer from understanding and coding complicated hardware optimizations
which can easily go wrong.
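As a concrete check with the shapes used below: a 64x64x128 threadblock tile is covered by
(64/32) x (64/32) = 2 x 2 warp tiles of shape 32x32x128, and each warp tile is in turn covered by
(32/16) x (32/16) x (128/32) = 2 x 2 x 4 iterations of the 16x16x32 TensorCore instruction shape.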
HYTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines
constitute the whole process of loading input data from global memory to shared memory, loading data
from shared memory to registers, doing matrix multiplication, store to global memory. The below flow
sequence shows a typical mma pipeline.
tensor in global memory -> registers -> tile in shared memory -> registers -> mma -> registers ->
output to global memory
The problem with a single pipeline is that each stage is synchronous: it has to wait until the
previous stage has finished executing. Some stages in the pipeline do not have fixed latency, for
example the loads from global memory and shared memory. Therefore, we can add one more pipeline with
a phase shift in the mma kernel to hide the latency of the global and shared memory loads.
Finally, the pipelines in a kernel look like

(1) tensor in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers ->
(5) mma -> (6) registers -> (7) output to global memory

(1) <null> -> (2) <null> -> (3) tensor in global memory -> (4) registers ->
(5) tile in shared memory -> (6) registers -> (7) mma -> (8) registers -> (9) output to global memory
This way, you can hide the second global memory load latency by doing computation on already loaded
input data.
A few more template variables are initialized, such as the mapping that decides which threadblock tile
of the output matrix is computed by which launched threadblock (the threadblock swizzle), and the GFX
architecture of the GPU you want to run on.
These are all put together to create a template variable which describes HYTLASS Implicit GEMM
kernel using hytlass::conv::device::ImplicitGemm template.
The next step is to initialize physical data, instantiate and initialize the HYTLASS kernel, and run it.
We use the HYTLASS utilities to initialize, fill and compare tensors, as they are simple and do not get
in the way of learning HYTLASS.
Once all the tensors are initialized and filled with data, we create an arguments tuple to launch the
HYTLASS kernel. It takes the problem size (by default N = 1, H = 14, W = 14, C = 1024), the filter size
(K = 256, R = 1, S = 1, C = 1024), padding, strides, dilation, the tensors, alpha, beta and, importantly,
the split k-dimension factor. Along with that, we query HYTLASS for any scratch-space memory required by
the kernel we instantiated. If so, we allocate it and pass it along with the other arguments to initialize
the HYTLASS kernel; the kernel is then launched.
In this example, we later launch a reference convolution kernel (from the HYTLASS utilities) to check
whether the output of the HYTLASS kernel matches the reference implicit GEMM implementation.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/conv/kernel/default_conv2d_fprop.h"
#include "hytlass/conv/device/implicit_gemm_convolution.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/convolution.h"
#include "hytlass/util/reference/device/convolution.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes datatype for input, output tensors and computation between
// elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = hytlass::half_t; // Data type of elements in input tensor
using ElementInputB = hytlass::half_t; // Data type of elements in input tensor
using ElementOutput = float; // Data type of elements in output tensor
using LayoutInputA = hytlass::layout::TensorNHWC;
using LayoutInputB = hytlass::layout::TensorNHWC;
using LayoutOutput = hytlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes GFX architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = hytlass::gemm::GemmShape<64, 64, 128>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = hytlass::gemm::GemmShape<32, 32, 128>; // Warp tile shape
// This code section describes the size of the MMA op
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 32>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
// 1 -> singlestage
// 2 -> pipelined
constexpr int NumStages = 1;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = hytlass::epilogue::thread::LinearCombinationRelu<
ElementOutput, // Data type of output matrix.
128 / hytlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue, // Data type for alpha/beta in linear combination
hytlass::epilogue::thread::ScaleType::NoBetaScaling>;
// EpilogueOp::Params params(ElementComputeEpilogue(0.5), ElementComputeEpilogue(0));
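// For orientation, with ScaleType::NoBetaScaling the epilogue above computes roughly
// D = ReLU(alpha * accumulator + C), where C is the broadcast bias. This is a host-side sketch
// mirroring the reference check later in this file, not the device code, and the helper name is
// illustrative only:
inline ElementOutput reference_bias_relu_epilogue(
    ElementAccumulator accum,        // Conv(A, B) accumulator value
    ElementOutput bias,              // broadcast bias element (one per output channel K)
    ElementComputeEpilogue alpha) {
  ElementComputeEpilogue z = alpha * ElementComputeEpilogue(accum) + ElementComputeEpilogue(bias);
  return ElementOutput(z > ElementComputeEpilogue(0) ? z : ElementComputeEpilogue(0));
}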
using Conv2dFpropKernel = typename hytlass::conv::kernel::DefaultConv2dFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
hytlass::conv::IteratorAlgorithm::kOptimized,
hytlass::conv::StrideSupport::kStrided,
8,
8
>::Kernel;
using ImplicitGemm = hytlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::Tensor4DCoord input_size;
hytlass::Tensor4DCoord filter_size;
hytlass::Tensor4DCoord padding;
hytlass::MatrixCoord conv_stride;
hytlass::MatrixCoord dilation;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool benchmark;
std::string tag;
Options():
help(false),
input_size(1, 14, 14, 1024),
filter_size(256, 1, 1, 1024),
padding(0, 0, 0, 0),
conv_stride(1, 1),
dilation(1, 1),
reference_check(true),
measure_performance(true),
iterations(20),
save_workspace(false),
alpha(1),
beta(1),
benchmark(false) { }
// Verify the problem size is compatible with the HYTLASS Convolution implementation.
bool valid() {
//
// HYTLASS attempts to load 128b vectors. Consequently,
    // all pointers, strides, and tensor extents must be divisible by 8 half_t elements.
//
// int const kAlignment = 128 / hytlass::sizeof_bits<ElementInputA>::value;
// if ((input_size.c() % kAlignment) ||
// (filter_size.n() % kAlignment)) {
// // misaligned tensors
// return false;
// }
// Invalid padding
if ((padding.h() != filter_size.h() / 2) ||
(padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
hytlass::Tensor4DCoord input_size,
hytlass::Tensor4DCoord filter_size) {
this->input_size = input_size;
this->filter_size = filter_size;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
if (cmd.check_cmd_line_flag("benchmark")) {
      benchmark = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.c() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
int32_t padding_h = filter_size.h() / 2;
int32_t padding_w = filter_size.w() / 2;
padding = {padding_h, padding_h, padding_w, padding_w};
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "13_tensorop_fused_conv2d_fprop example\n\n"
<< " This example uses Turing's Tensor Core operators on int4 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./gfx928_tensorop_conv2d_bias_relu --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
<< "$ ./gfx928_tensorop_conv2d_bias_relu --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
hytlass::Tensor4DCoord output_size() const {
return hytlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hytlass::Status reference_check;
hipError_t error;
Result():
runtime_ms(0),
gflops(0),
status(hytlass::Status::kSuccess),
reference_check(hytlass::Status::kInvalid),
error(hipSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the HYTLASS Utilities.
//
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias({1, 1, 1, options.output_size().c()});
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(4),
ElementInputA(-4),
hytlass::MantissaInBits<ElementOutput>::bits);
// Fill tensor B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
2,
ElementInputB(4),
ElementInputB(-4),
hytlass::MantissaInBits<ElementOutput>::bits);
// <- Fill matrix C on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_c_bias.host_view(),
1,
ElementOutput(1),
ElementOutput(-1),
0);
// Fill tensor D on host with zeros
hytlass::reference::host::TensorFill(
tensor_d.host_view());
// Fill tensor D for reference on host with zeros
hytlass::reference::host::TensorFill(
tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c_bias.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
//
// Define arguments for HYTLASS Convolution
//
// mode (kCrossCorrelation or kConvolution)
hytlass::conv::Mode mode = hytlass::conv::Mode::kCrossCorrelation;
// Split K dimension into 1 partitions
int split_k_slices = 1;
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
// Construct Conv2dProblemSize with user defined output size
hytlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices);
// Construct ImplicitGemm::Argument structure with conv2d
// problem size, data pointers, and epilogue values
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
{tensor_c_bias.device_data(), LayoutOutput::Stride(0)},
tensor_d.device_ref(),
{alpha,beta},
};
//
// Initialize HYTLASS Convolution
//
ImplicitGemm implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_op.can_implement(arguments);
HYTLASS_CHECK(result.status);
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(result.status);
//
// Launch initialized HYTLASS kernel
//
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on host...\n";
// Compute with reference implementation
hytlass::reference::device::Conv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
hytlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c_bias.device_ref(),
tensor_ref_d.device_ref(),
alpha,
beta
);
// Check if output from HYTLASS kernel and reference kernel are equal or not
tensor_d.sync_host();
tensor_ref_d.sync_host();
// Compute bias + relu in host code
for (int n = 0; n < problem_size.N; ++n) {
for (int p = 0; p < problem_size.P; ++p) {
for (int q = 0; q < problem_size.Q; ++q) {
for (int k = 0; k < problem_size.K; ++k) {
tensor_ref_d.at({n, p, q, k}) =
std::max(ElementOutput(0),
ElementOutput(tensor_ref_d.at({n, p, q, k}) +
tensor_c_bias.at({0, 0, 0, k})));
}
}
}
}
const ElementOutput non_zero_floor(1e-6);
ElementOutput eps(1e-3);
bool passed = hytlass::reference::host::TensorRelativelyEquals(tensor_d.host_view(), tensor_ref_d.host_view(), eps, non_zero_floor);
if (!passed) {
result.reference_check = hytlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
}
else {
result.reference_check = hytlass::Status::kSuccess;
std::cout << "Passed.\n";
}
}
else {
result.reference_check = hytlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "09_tensor_conv_workspace_conv2dfprop_"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
<< ".dat";
std::ofstream output_workspace(ss.str());
output_workspace
<< "Input = \n" << tensor_a.host_view() << "\n\n"
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.benchmark) {
// Benchmark several layers
int batch_sizes[] = {1, 32, 64, 128, 256, 512};
struct Benchmark {
int h, w, c, k, r, s;
} layers[] = {
{56, 56, 64, 256, 1, 1},
{56, 56, 64, 64, 1, 1},
{56, 56, 64, 64, 3, 3},
{56, 56, 256, 64, 1, 1},
{56, 56, 256, 512, 1, 1},
{56, 56, 256, 128, 1, 1},
{28, 28, 128, 128, 3, 3},
{28, 28, 128, 512, 1, 1},
{28, 28, 512, 128, 1, 1},
{28, 28, 512, 1024, 1, 1},
{28, 28, 512, 256, 1, 1},
{14, 14, 256, 256, 3, 3},
{14, 14, 256, 1024, 1, 1},
{14, 14, 1024, 256, 1, 1},
{14, 14, 1024, 2048, 1, 1},
{14, 14, 1024, 512, 1, 1},
{7, 7, 512, 512, 3, 3},
};
Result::print_header(std::cout, options) << std::endl;
int idx = 1;
for (auto const &layer : layers) {
for (auto N : batch_sizes) {
options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});
Result result = profile_convolution(options);
result.print(std::cout, idx, options) << std::endl;
}
++idx;
}
}
else {
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gather_k_fusion
gather_k_fusion.cu
)
hytlass_example_add_executable(
gather_scatter_fusion
gather_scatter_fusion.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// This example fuses a gather along the K dimension into the GEMM kernel. The
// gather operation is controlled by an index vector that selects which columns
// of A (and the corresponding rows of B) participate in the GEMM.
//
// Suppose A is column major and B is row major, as in this example. The pseudo
// code of the fused kernel is essentially
//
//   for (int i = 0; i < problem_size.m(); ++i) {
//     for (int j = 0; j < problem_size.n(); ++j) {
//       for (int k = 0; k < options.index_size; ++k) {
//         int a_col_b_row = tensor_indices.at({k, 0});
//         tensor_d_ref.at({i, j}) +=
//             alpha * tensor_a.at({i, a_col_b_row}) * tensor_b.at({a_col_b_row, j});
//       }
//       tensor_d_ref.at({i, j}) += beta * tensor_c.at({i, j});
//     }
//   }
//
// Note that the index vector contains unique random integers with maximum value K - 1.
//
// The gather/scatter operation works best when we can still keep the biggest
// alignment. For example, when the matrix is row major, we select rows. When
// the matrix is column major, we select columns.
//
// Not all combinations of gather and scatter are legal. For example, if A is
// row major and C/D is column major, we cannot gather A and scatter C/D at the
// same time.
//
// Also, for the sake of performance, we do not check that the index values are legal or that the
// index array pointer is valid.
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <math.h>
#include <assert.h>
#include <hip/hip_runtime.h>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <random>
#include <numeric>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm_universal.h"
#include "hytlass/epilogue/thread/linear_combination.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::gemm::GemmCoord problem_size;
int index_size;
bool reference_check;
int iterations;
Options():
help(false),
problem_size({1024, 1024, 1024}),
index_size(240),
reference_check(true),
iterations(50)
{}
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("index_size", index_size);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "14_gather_scatter_fusion example\n\n"
<< " This example uses the HYTLASS Library to fuse gather/scatter into GEMM\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --index_size=<int> size of N dimension index\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/14_gather_scatter_fusion/gather_k_fusion --m=1024 --n=512 --k=1024 \\\n"
<< " --index_size=128\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = problem_size.m() * problem_size.n() * int64_t(index_size);
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
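  // Worked example with the defaults above (m = n = 1024, index_size = 240):
  // fmas = 1024 * 1024 * 240 = 251,658,240, so the kernel performs about 2 * fmas ~= 0.503 GFLOP.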
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = hytlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = hytlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
// The code section below describes matrix layout of input and output matrices.
// Column major for matrix A; row major for matrix B and the output matrices C/D.
//
using LayoutInputA = hytlass::layout::ColumnMajor;
using LayoutInputB = hytlass::layout::RowMajor;
using LayoutOutput = hytlass::layout::RowMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes GFX architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
hytlass::gemm::GemmShape<128, 256, 32>;
// This code section describes tile size a warp will compute
using ShapeMMAWarp = hytlass::gemm::GemmShape<64, 128, 32>;
// This code section describes the size of MMA op
using ShapeMMAOp = hytlass::gemm::GemmShape<16, 16, 16>;
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- identity threadblock swizzle
// Define the epilogue operation as LinearCombination. This is approximately equal to
//
// d_ij = alpha * sum_k(a_ik * b_kj) + c_ij
//
using EpilogueOp = hytlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / hytlass::sizeof_bits<ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This becomes
// the vector width of math instructions in
// epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha in linear combination function
// Number of pipelines you want to use
constexpr int NumStages = 1;
using Gemm = hytlass::gemm::device::GemmUniversal<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
8, /*alignmentA*/
8, /*alignmentB*/
hytlass::arch::OpMultiplyAdd,
hytlass::ComplexTransform::kNone,
hytlass::ComplexTransform::kNone,
true, /*GatherA*/
true, /*GatherB*/
false, /*ScatterD*/
hytlass::layout::NoPermute,
hytlass::layout::NoPermute,
hytlass::layout::NoPermute,
true /*buffer load*/
>;
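// Note: with GatherA and GatherB enabled (and ScatterD disabled), the kernel reads only the columns
// of A and rows of B named by an index vector, so the K extent it iterates over is the index count
// rather than the full K of the dense tensors. The index vector is supplied through Gemm::Arguments
// below; see the reference check in run() for the exact math this is expected to reproduce.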
int run(Options &options) {
// ================================================================================
// Initialization setup
// Create a tuple of problem size for matrix multiplication
hytlass::gemm::GemmCoord problem_size = options.problem_size;
  // Effective problem size seen by the kernel: the K extent is the number of gathered indices
hytlass::gemm::GemmCoord problem_size_real(problem_size.m(),
problem_size.n(),
options.index_size);
// Initialize tensors using HYTLASS helper functions
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// HYTLASS kernel
// Fill input and output matrices on host using HYTLASS helper functions
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(3),
ElementInputA(-3),
0); // <- Fill matrix A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
      ElementInputB(3),
      ElementInputB(-3),
0); // <- Fill matrix B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0); // <- Fill matrix C on host with uniform-distribution random data
hytlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
hytlass::HostTensor<int, LayoutOutput> tensor_indices(
{options.index_size, 1}); // <- Create scatter indices with dimensions val_len x 1
  // <- Fill tensor_indices on host with unique random integers in [0, problem_size.k())
  std::vector<int> to_fill(problem_size.k()); // vector with ints.
  std::iota(std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ..., problem_size.k() - 1
{ // std::random_shuffle was deprecated in C++14 and removed in C++17
std::random_device make_seed;
std::mt19937 source_of_randomness(make_seed());
std::shuffle(to_fill.begin(), to_fill.end(), source_of_randomness);
}
memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int));
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_indices.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
// Initialize alpha/beta for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(1);
// Split K dimension into 1 partitions
int split_k_slices = 1;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated HYTLASS kernel
typename Gemm::Arguments arguments{
hytlass::gemm::GemmUniversalMode::kGemm,
problem_size_real, // <- problem size of matrix multiplication
split_k_slices, // <- k-dimension split factor
{alpha, beta}, // <- alpha, beta
tensor_a.device_data(), // <- reference to matrix A on device
tensor_b.device_data(), // <- reference to matrix B on device
tensor_c.device_data(), // <- reference to matrix C on device
tensor_d.device_data(), // <- reference to matrix D on device
tensor_a.layout().capacity(hytlass::make_Coord(problem_size.m(), options.index_size)),
tensor_b.layout().capacity(hytlass::make_Coord(options.index_size, problem_size.n())),
tensor_c.layout().capacity(problem_size.mn()),
tensor_d.layout().capacity(problem_size.mn()),
tensor_a.layout().stride(),
tensor_b.layout().stride(),
tensor_c.layout().stride(),
tensor_d.layout().stride(),
tensor_indices.device_data(), // <- pointer to index vector to gather A on device
tensor_indices.device_data(), // <- pointer to index vector to gather B on device
nullptr}; // <- pointer to index vector to scatter D on device
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate HYTLASS kernel depending on templates
Gemm gemm_op;
// Check the problem size is supported or not
hytlass::Status status = gemm_op.can_implement(arguments);
HYTLASS_CHECK(status);
// Initialize HYTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(status);
// CPU reference calculation
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_ref(problem_size.mn());
hytlass::reference::host::TensorFill(
tensor_d_ref.host_view()); // <- Fill matrix D on host with zeros
status = gemm_op();
(void)hipDeviceSynchronize();
HYTLASS_CHECK(status);
if (options.reference_check) {
for (int i = 0; i < problem_size.m(); ++i) {
for (int j = 0; j < problem_size.n(); ++j) {
for (int k = 0; k < options.index_size; ++k) {
int _k = tensor_indices.at({k, 0});
tensor_d_ref.at({i, j}) += (alpha * tensor_a.at({i, _k}) * tensor_b.at({_k, j}));
}
tensor_d_ref.at({i, j}) += (beta * tensor_c.at({i, j}));
}
}
// Copy output data from HYTLASS and reference kernel to host for comparison
tensor_d.sync_host();
bool passed = hytlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_d_ref.host_view());
if (!passed) {
std::cout << "Failed!\n";
std::stringstream fname;
fname << "error_gather_k_GEMM_fusion.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A =\n" << tensor_a.host_view()
<< "\nB =\n" << tensor_b.host_view()
<< "\nindices =\n" << tensor_indices.host_view()
<< "\nC =\n" << tensor_c.host_view()
<< "\n\nReference =\n" << tensor_d_ref.host_view()
<< "\nComputed =\n" << tensor_d.host_view();
return -1;
} else {
std::cout << "Passed!\n";
}
}
// Result structure
Result result;
//
// Construct events
//
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
}
// warm up
for (int iter = 0; iter < 10; ++iter) {
status = gemm_op();
HYTLASS_CHECK(status);
}
// Record an event at the start of a series of GEMMs
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
// Launch initialized HYTLASS kernel
status = gemm_op();
HYTLASS_CHECK(status);
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
std::cout << "Runtime: " << result.runtime_ms << " ms\n";
std::cout << "GFLOPs: " << result.gflops << "\n";
return 0;
}
int main(int argc, const char ** argv) {
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << "\n";
return 0;
}
if (!options.valid()) {
std::cerr << "Invalid problem." << "\n";
return -1;
}
return run(options);
}
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// This example fuses gather before GEMM and scatter after GEMM into the same
// GEMM kernel. The gather and scatter operations are controlled by an index vector
// used to select rows or columns from the A, B, C or D matrices.
//
// Suppose all matrices are column major. The pseudo code of the fused kernel
// in this example is essentially
//
// for (int i = 0; i < problem_size.m(); ++i) {
//   for (int j = 0; j < options.index_size; ++j) {
//     int b_c_d_col = tensor_indices.at({j, 0});
//
//     for (int k = 0; k < problem_size.k(); ++k) {
//       tensor_d_ref.at({i, b_c_d_col}) +=
//         alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
//     }
//     tensor_d_ref.at({i, b_c_d_col}) += beta * tensor_c.at({i, b_c_d_col});
//   }
// }
//
// Note that the index vector contains unique random integers in the range [0, N - 1].
//
// The gather/scatter operation works best when it preserves the widest possible
// alignment: when the matrix is row major, we select rows; when the matrix is
// column major, we select columns.
//
// Not all combinations of gather and scatter are legal. For example, if A is
// row major and C/D are column major, we cannot gather A and scatter C/D at the
// same time.
//
// Also, for the sake of performance, we do not check that the index values are
// in range or that the index array pointer is valid.
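//
// A minimal host-side sketch of how such an index vector is prepared (this mirrors the
// initialization performed in run() below; the exact sizes are illustrative only):
//
//   std::vector<int> to_fill(problem_size.n());
//   std::iota(to_fill.begin(), to_fill.end(), 0);    // 0, 1, ..., N - 1
//   std::shuffle(to_fill.begin(), to_fill.end(), std::mt19937{std::random_device{}()});
//   // keep the first index_size entries as the unique gather/scatter indices
//   memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int));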
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <math.h>
#include <assert.h>
#include <hip/hip_runtime.h>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <random>
#include <numeric>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm_universal.h"
#include "hytlass/epilogue/thread/linear_combination.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::gemm::GemmCoord problem_size;
int index_size;
bool reference_check;
int iterations;
Options():
help(false),
problem_size({248, 1024, 1024}),
index_size(240),
reference_check(true),
iterations(50)
{}
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("index_size", index_size);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "14_gather_scatter_fusion example\n\n"
<< " This example uses the HYTLASS Library to fuse gather/scatter into GEMM\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --index_size=<int> size of N dimension index\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/14_gather_scatter_fusion/gather_scatter_fusion --m=1024 --n=512 --k=1024 \\\n"
<< " --index_size=128\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = problem_size.m() * int64_t(index_size) * problem_size.k();
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
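// A rough sanity check of the gflops() computation above with the default options
// (m = 248, index_size = 240, k = 1024); the runtime used here is illustrative only:
//   fmas  = 248 * 240 * 1024 = 60,948,480 multiply-adds
//   flops = 2 * fmas         = 121,896,960
//   at an assumed runtime of 1 ms this yields roughly 121.9 GFLOP/s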
///////////////////////////////////////////////////////////////////////////////////////////////////
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = hytlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = hytlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
// The code section below describes matrix layout of input and output matrices.
// Column Major for Matrix A, B and C.
//
using LayoutInputA = hytlass::layout::ColumnMajor;
using LayoutInputB = hytlass::layout::ColumnMajor;
using LayoutOutput = hytlass::layout::ColumnMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes GFX architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
hytlass::gemm::GemmShape<128, 256, 32>;
// This code section describes tile size a warp will compute
using ShapeMMAWarp = hytlass::gemm::GemmShape<64, 128, 32>;
// This code section describes the size of MMA op
using ShapeMMAOp = hytlass::gemm::GemmShape<16, 16, 16>;
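// With these shapes, each threadblock tile (128 x 256 x 32) is covered by
// (128/64) x (256/128) = 2 x 2 = 4 warp tiles, and each warp tile (64 x 128 x 32) is covered by
// (64/16) x (128/16) = 4 x 8 MMA-op tiles per 16-deep slice of K (two such slices per warp tile).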
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- identity mapping from threadblock index to output tile
// Define the epilogue operation as LinearCombination. This is approximately equal to
//
// d_ij = alpha * sum_k(a_ik * b_kj) + c_ij
//
using EpilogueOp = hytlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / hytlass::sizeof_bits<ElementOutput>::value, // <- number of elements per vectorized
// memory access. For the float output
// used here, this is 128 / 32 = 4 elements.
// This also becomes the vector width of
// math instructions in the epilogue.
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
// Number of pipelines you want to use
constexpr int NumStages = 1;
using Gemm = hytlass::gemm::device::GemmUniversal<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
8, /*alignmentA*/
8, /*alignmentB*/
hytlass::arch::OpMultiplyAdd,
hytlass::ComplexTransform::kNone,
hytlass::ComplexTransform::kNone,
false, /*GatherA*/
true, /*GatherB*/
true, /*ScatterD*/
hytlass::layout::NoPermute,
hytlass::layout::NoPermute,
hytlass::layout::NoPermute,
true>;
int run(Options &options) {
// ================================================================================
// Initialization setup
// Create a tuple of problem size for matrix multiplication
hytlass::gemm::GemmCoord problem_size = options.problem_size;
// Create the problem size actually run by the kernel: N is replaced by index_size
hytlass::gemm::GemmCoord problem_size_real(problem_size.m(),
options.index_size,
problem_size.k());
// Initialize tensors using HYTLASS helper functions
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_scattered(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// HYTLASS kernel
// Fill input and output matrices on host using HYTLASS helper functions
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(3),
ElementInputA(-3),
0); // <- Fill matrix A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(3),
ElementInputB(-3),
0); // <- Fill matrix B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0); // <- Fill matrix C on host with uniform-distribution random data
hytlass::reference::host::TensorFill(
tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros
hytlass::HostTensor<int, LayoutOutput> tensor_indices(
{options.index_size, 1}); // <- Create gather/scatter indices with dimensions index_size x 1
// <- Fill tensor_indices on host with unique random integers
std::vector<int> to_fill(problem_size.n()); // vector of ints
std::iota(std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ..., problem_size.n() - 1
{ // std::random_shuffle was deprecated in C++14 and removed in C++17
std::random_device make_seed;
std::mt19937 source_of_randomness(make_seed());
std::shuffle(to_fill.begin(), to_fill.end(), source_of_randomness);
}
memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int));
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_indices.sync_device();
tensor_c.sync_device();
tensor_d_scattered.sync_device();
// Initialize alpha/beta for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(1);
// Split K dimension into 1 partition
int split_k_slices = 1;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated HYTLASS kernel
typename Gemm::Arguments arguments{
hytlass::gemm::GemmUniversalMode::kGemm,
problem_size_real, // <- problem size of matrix multiplication
split_k_slices, // <- k-dimension split factor
{alpha, beta}, // <- alpha, beta
tensor_a.device_data(), // <- reference to matrix A on device
tensor_b.device_data(), // <- reference to matrix B on device
tensor_c.device_data(), // <- reference to matrix C on device
tensor_d_scattered.device_data(), // <- reference to matrix D on device
tensor_a.layout().capacity(problem_size.mk()),
tensor_b.layout().capacity(hytlass::make_Coord(options.index_size, problem_size.k())),
tensor_c.layout().capacity(problem_size.mn()),
tensor_d_scattered.layout().capacity(problem_size.mn()),
tensor_a.layout().stride(),
tensor_b.layout().stride(),
tensor_c.layout().stride(),
tensor_d_scattered.layout().stride(),
nullptr, // <- pointer to index vector to gather A on device
tensor_indices.device_data(), // <- pointer to index vector to gather B on device
tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate HYTLASS kernel depending on templates
Gemm gemm_op;
// Check the problem size is supported or not
hytlass::Status status = gemm_op.can_implement(arguments);
HYTLASS_CHECK(status);
// Initialize HYTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(status);
// CPU reference calculation
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_ref(problem_size.mn());
hytlass::reference::host::TensorFill(
tensor_d_ref.host_view()); // <- Fill matrix D on host with zeros
status = gemm_op();
(void)hipDeviceSynchronize();
HYTLASS_CHECK(status);
if (options.reference_check) {
for (int i = 0; i < problem_size.m(); ++i) {
for (int j = 0; j < options.index_size; ++j) {
int b_c_d_col = tensor_indices.at({j, 0});
for (int k = 0; k < problem_size.k(); ++k) {
tensor_d_ref.at({i, b_c_d_col}) +=
alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
}
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
}
}
// Copy output data from HYTLASS and reference kernel to host for comparison
tensor_d_scattered.sync_host();
bool passed = hytlass::reference::host::TensorEquals(
tensor_d_scattered.host_view(),
tensor_d_ref.host_view());
if (!passed) {
std::cout << "Failed!\n";
std::stringstream fname;
fname << "error_gather_GEMM_scatter_fusion.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A =\n" << tensor_a.host_view()
<< "\nB =\n" << tensor_b.host_view()
<< "\nindices =\n" << tensor_indices.host_view()
<< "\nC =\n" << tensor_c.host_view()
<< "\n\nReference =\n" << tensor_d_ref.host_view()
<< "\nComputed =\n" << tensor_d_scattered.host_view();
return -1;
} else {
std::cout << "Passed!\n";
}
}
// Result structure
Result result;
//
// Construct events
//
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
}
// warm up
for (int iter = 0; iter < 10; ++iter) {
status = gemm_op();
HYTLASS_CHECK(status);
}
// Record an event at the start of a series of GEMMs
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
// Launch initialized HYTLASS kernel
status = gemm_op();
HYTLASS_CHECK(status);
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
std::cout << "Runtime: " << result.runtime_ms << " ms\n";
std::cout << "GFLOPs: " << result.gflops << "\n";
return 0;
}
int main(int argc, const char ** argv) {
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << "\n";
return 0;
}
if (!options.valid()) {
std::cerr << "Invalid problem." << "\n";
return -1;
}
return run(options);
}
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
hute_gfx928_group_gemm
hute_gfx928_group_gemm.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2023 HYGON CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped GEMM example using HYTLASS 3 APIs for DCU architecture.
For this example all scheduling work is performed on the device.
To run this example:
$ ./hute_gfx928_group_gemm --m=2048 --n=2048 --k=2048 --groups=10
The above command sizes all 10 groups at the given m, n, and k values.
Omitting any of the problem dimensions randomizes that dimension across the different groups.
The same applies to alpha and beta, which are randomized across the groups when not specified.
To run this example for a set of problems using the benchmark option:
$ ./hute_gfx928_group_gemm --benchmark=./test_benchmark.txt
where test_benchmark.txt may look like:
0 256x512x128
1 256x512x512
2 512x256x128
3 256x256x128
4 256x512x1024
5 1024x512x128 and so on
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <float.h>
#include "hytlass/hytlass.h"
#include "hute/tensor.hpp"
#include "hytlass/tensor_ref.h"
#include "hytlass/epilogue/collective/default_epilogue.hpp"
#include "hytlass/epilogue/thread/linear_combination.h"
#include "hytlass/gemm/dispatch_policy.hpp"
#include "hytlass/gemm/group_array_problem_shape.hpp"
#include "hytlass/gemm/collective/collective_builder.hpp"
#include "hytlass/epilogue/collective/collective_builder.hpp"
#include "hytlass/gemm/device/gemm_universal_adapter.h"
#include "hytlass/gemm/kernel/gemm_universal.hpp"
#include "hytlass/util/command_line.h"
#include "hytlass/util/distribution.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/packed_stride.hpp"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/device/tensor_compare.h"
#include "hytlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace hute;
using ProblemShape = hytlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ElementA = hytlass::half_t; // Element type for A matrix operand
using ElementB = hytlass::half_t; // Element type for B matrix operand
using ElementC = hytlass::half_t; // Element type for C and D matrix operands
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using LayoutA = hytlass::layout::ColumnMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / hytlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using LayoutB = hytlass::layout::RowMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / hytlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using LayoutC = hytlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / hytlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = hytlass::arch::Gfx928; // Tag indicating the minimum Gfx that supports the intended feature
using OperatorClass = hytlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>;
using WarpShape_MNK = Shape<_2, _2, _1>;
using InstructionShape_MNK = Shape<_32, _32, _16>;
using StageCountType = hytlass::gemm::collective::StageCount<2>;
using KernelSchedule = hytlass::gemm::KernelPtrArraySpecialized; // Kernel to launch
using EpilogueSchedule = hytlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename hytlass::epilogue::collective::CollectiveBuilder<
hytlass::arch::Gfx928, hytlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
hytlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
ElementC, LayoutC *, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename hytlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
TileShape, WarpShape_MNK, InstructionShape_MNK, ClusterShape,
StageCountType,
KernelSchedule
>::CollectiveOp;
using GemmKernel = hytlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = hytlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using DeviceGemmReference = hytlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
// Host-side allocations
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<StrideA> stride_A_host;
std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
// Device-side allocations
hytlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
hytlass::DeviceAllocation<typename Gemm::ElementA> block_A;
hytlass::DeviceAllocation<typename Gemm::ElementB> block_B;
hytlass::DeviceAllocation<typename Gemm::ElementC> block_C;
hytlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
hytlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
hytlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
hytlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
hytlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
hytlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
hytlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
hytlass::DeviceAllocation<StrideA> stride_A;
hytlass::DeviceAllocation<StrideB> stride_B;
hytlass::DeviceAllocation<StrideC> stride_C;
hytlass::DeviceAllocation<StrideD> stride_D;
// Note, this is an array of pointers to alpha and beta scaling values per group
hytlass::DeviceAllocation<ElementAccumulator*> alpha_device;
hytlass::DeviceAllocation<ElementAccumulator*> beta_device;
hytlass::DeviceAllocation<ElementAccumulator> block_alpha;
hytlass::DeviceAllocation<ElementAccumulator> block_beta;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
float alpha = FLT_MAX;
float beta = FLT_MAX;
int iterations = 10;
int m = 1024, n = 2048, k = 512, groups = 10;
std::string benchmark_path;
std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host;
int const tma_alignment_bits = 128;
int const alignment = tma_alignment_bits / hytlass::sizeof_bits<ElementA>::value;
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX);
cmd.get_cmd_line_argument("beta", beta, FLT_MAX);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
// Decide how to initialize the problems
if (!benchmark_path.empty()) {
if (!benchmark_problems()) {
problem_sizes_host.clear();
return;
}
}
else {
randomize_problems(cmd);
}
}
void randomize_problems(hytlass::CommandLine &cmd) {
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes_host.reserve(groups);
for (int i = groups; i > 0; i--) {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = ((rand() % 512) + 1);
}
if (n < 1) {
n = ((rand() % 512) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 64) + 1);
}
problem_sizes_host.push_back({m, n, k});
}
}
/// Load a benchmark
bool benchmark_problems() {
std::ifstream file(benchmark_path);
if (!file.good()) {
return false;
}
while (file.good()) {
int idx = -1;
std::string extent_str;
file >> idx >> extent_str;
if (idx < 0 || extent_str.empty()) {
break;
}
hytlass::gemm::GemmCoord extent;
std::vector<std::string> tokens;
hytlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
// round up
if (x % alignment) {
x += (alignment - (x % alignment));
}
extent.at(i) = x;
}
if (extent.product()) {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
}
groups = static_cast<int>(problem_sizes_host.size());
return true;
}
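// Illustrative note on the rounding above: with ElementA = half_t the alignment is
// 128 / 16 = 8 elements, so a hypothetical benchmark line such as "0 250x300x100" would be
// rounded up to a 256x304x104 problem before being added to problem_sizes_host.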
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "hute_group_gemm\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM for all groups\n"
<< " --n=<int> Sets the N extent of the GEMM for all groups\n"
<< " --k=<int> Sets the K extent of the GEMM for all groups\n"
<< " --groups=<int> Sets the number of individual GEMM problems for Grouped GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n"
<< " --benchmark=<str> Executes a benchmark problem size.\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "hute_gfx928_group_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s, std::vector<typename ProblemShape::UnderlyingProblemShape> problem_sizes_host) const
{
// Number of real-valued multiply-adds
uint64_t fmas = uint64_t();
for (auto const & problem : problem_sizes_host) {
fmas += static_cast<uint64_t>(get<0>(problem)) *
static_cast<uint64_t>(get<1>(problem)) *
static_cast<uint64_t>(get<2>(problem));
}
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * uint64_t(fmas);
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
hytlass::Status status = hytlass::Status::kSuccess;
hipError_t error = hipSuccess;
bool passed = false;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
template <class Element>
bool initialize_block(
hytlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = hytlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = static_cast<Element>(1);
scope_min = static_cast<Element>(1);
} else if (bits_input <= 8) {
scope_max = static_cast<Element>(1);
scope_min = static_cast<Element>(1);
} else {
scope_max = static_cast<Element>(4);
scope_min = static_cast<Element>(2);
}
hytlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
/// Allocates device-side data
void allocate(const Options &options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
int64_t elements_A = M * K;
int64_t elements_B = K * N;
int64_t elements_C = M * N;
int64_t elements_D = M * N;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
stride_A_host.push_back(hytlass::make_hute_packed_stride(StrideA{}, hute::make_shape(M, K, Int<1>{})));
stride_B_host.push_back(hytlass::make_hute_packed_stride(StrideB{}, hute::make_shape(N, K, Int<1>{})));
stride_C_host.push_back(hytlass::make_hute_packed_stride(StrideC{}, hute::make_shape(M, N, Int<1>{})));
stride_D_host.push_back(hytlass::make_hute_packed_stride(StrideD{}, hute::make_shape(M, N, Int<1>{})));
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
uint64_t seed = 2020;
problem_sizes.reset(options.groups);
problem_sizes.copy_from_host(options.problem_sizes_host.data());
//
// Assign pointers
//
std::vector<ElementA *> ptr_A_host(options.groups);
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementC *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.groups);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(options.groups);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.groups);
ptr_D.copy_from_host(ptr_D_host.data());
stride_A.reset(options.groups);
stride_A.copy_from_host(stride_A_host.data());
stride_B.reset(options.groups);
stride_B.copy_from_host(stride_B_host.data());
stride_C.reset(options.groups);
stride_C.copy_from_host(stride_C_host.data());
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
{
hytlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = hytlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::EpilogueOutputOp::Params params;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
params = typename Gemm::EpilogueOutputOp::Params(
ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
}
typename Gemm::Arguments arguments;
if (host_problem_shapes_available) {
arguments = typename Gemm::Arguments {
hytlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else {
arguments = typename Gemm::Arguments {
hytlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
return arguments;
}
bool verify(const Options &options) {
bool passed = true;
for (int32_t i = 0; i < options.groups; ++i) {
auto problem = options.problem_sizes_host.at(i);
auto M = get<0>(problem);
auto N = get<1>(problem);
auto K = get<2>(problem);
hytlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({M, K}));
hytlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N}));
hytlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N}));
hytlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N}));
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
{M, N, K},
ElementAccumulator(alpha_host.at(i)),
ref_A,
ref_B,
ElementAccumulator(beta_host.at(i)),
ref_C,
ref_D);
// Wait for kernel to finish
HIP_CHECK(hipDeviceSynchronize());
// Check if output from HYTLASS kernel and reference kernel are equal or not
passed &= hytlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N);
#if 0
std::cout << "Group: " << i << " Status: " << passed << std::endl;
#endif
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
// Instantiate HYTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
HYTLASS_CHECK(gemm.can_implement(arguments));
// Initialize HYTLASS kernel with arguments and workspace pointer
HYTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
HYTLASS_CHECK(gemm.run());
// Check if output from HYTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
HYTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
HYTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average setup and runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
}
return 0;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate HYTLASS kernels
//
run<Gemm>(options);
run<Gemm>(options, false /*host_problem_shapes_available*/);
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gemm_softmax
gemm_softmax.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example uses the HYTLASS Library to compute a batched GEMM followed by a softmax
applied to each row of the GEMM result (see gemm_with_softmax.h).
*/
#include <cmath>
#include <iostream>
#include <vector>
#include <limits>
#include "hytlass/hytlass.h"
#include "hytlass/arch/memory.h"
#include "hytlass/arch/memory_gfx928.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/reference/host/gemm_complex.h"
#include "hytlass/util/reference/device/gemm_complex.h"
#include "hytlass/util/reference/host/tensor_reduce.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_norm.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/device/tensor_fill.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/error_metrics.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/layout/matrix.h"
#include "hytlass/epilogue/thread/linear_combination.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "gemm_with_softmax.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#define TRACE(x) { std::cout << "gemm_softmax.cu:" << __LINE__ << " " << x << std::endl; }
/////////////////////////////////////////////////////////////////////////////////////////////////
enum class Disposition {
kPassed,
kIncorrect,
kNotVerified
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::gemm::GemmCoord problem_size;
int batch_count;
int iterations;
unsigned seed;
float alpha;
float beta;
bool verification_enabled;
float tolerance;
Options():
help(false),
problem_size({16, 24, 64}),
batch_count(16),
iterations(20),
seed(2022),
alpha(1),
beta(0),
verification_enabled(true),
tolerance(1e-5f)
{ }
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("batch_count", batch_count);
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("verify", verification_enabled);
cmd.get_cmd_line_argument("seed", seed);
cmd.get_cmd_line_argument("tolerance", tolerance);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "16_gemm_softmax example\n\n"
<< " This example uses the HYTLASS Library to compute GEMM + Softmax for arbitrary problem sizes.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --batch_count=<int> Batch number\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --seed=<int> Random number seed (1*)\n\n"
<< " --iterations=<int> Number of profiling iterations to perform (0 to disable profiling).\n\n"
<< " --verify=<bool> If true, performs reference calculation.\n\n"
<< " --tolerance <float> Error tolerance\n"
;
out << "\n\nExamples:\n\n"
<< "$ ./gemm_softmax --m=1024 --n=512 \\\n"
<< " --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Returns true if the environment and Toolkit support this
bool supported(bool verbose = true) const {
return true;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Testbed {
//
// Type definitions
//
using ElementA = hytlass::half_t;
using ElementB = hytlass::half_t;
using ElementC = hytlass::half_t;
using ElementCompute = float;
using ElementD = ElementC;
using ElementSoftmax = ElementC;
using LayoutA = hytlass::layout::RowMajor;
using LayoutB = hytlass::layout::ColumnMajor;
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>;
using OperatorClass = hytlass::arch::OpClassTensorOp;
using ArchTag = hytlass::arch::Gfx928;
// ApplyShape has a large impact on the final softmax performance.
// Set ApplyShape::kColumn to the next multiple of 32 at or above (gemm_N / alignment).
// Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn).
using ApplyShape = hytlass::MatrixShape<1, 1024>;
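// A worked example of the rule above (the numbers are illustrative only): assuming half_t data
// with a 128-bit alignment of 8 elements, gemm_N = 512 gives gemm_N / alignment = 64; the next
// multiple of 32 at or above 64 is 64, so kColumn = 64 and kRow = max(1, 128 / 64) = 2,
// i.e. hytlass::MatrixShape<2, 64>.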
static int const kStages = 1;
/// Linear scaling operator
using EpilogueFunctorOp = hytlass::epilogue::thread::LinearCombination<
ElementC,
128 / hytlass::sizeof_bits<ElementC>::value,
ElementCompute,
ElementCompute
>;
using GemmSoftmax = hytlass::GemmSoftmax<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC,
ElementCompute,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueFunctorOp,
kStages,
ApplyShape
>;
using ElementNorm = typename GemmSoftmax::ElementNorm;
using ElementSum = typename GemmSoftmax::ElementSum;
using LayoutC = typename GemmSoftmax::LayoutC;
using LayoutN = typename GemmSoftmax::LayoutN;
using LayoutS = typename GemmSoftmax::LayoutS;
using MatrixCoord = typename LayoutC::TensorCoord;
//
// Data members
//
Options const &options;
hytlass::HostTensor<ElementNorm, LayoutC> reference_N;
hytlass::DeviceAllocation<ElementA> block_A;
hytlass::DeviceAllocation<ElementB> block_B;
hytlass::DeviceAllocation<ElementC> block_C;
hytlass::DeviceAllocation<ElementD> block_D;
hytlass::DeviceAllocation<ElementD> block_Ref;
hytlass::DeviceAllocation<ElementSoftmax> block_Softmax;
hytlass::DeviceAllocation<ElementNorm> block_Norm;
hytlass::DeviceAllocation<ElementSum> block_Sum;
int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN;
hytlass::gemm::GemmCoord problem = options.problem_size;
int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0);
int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0);
int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0);
// fixed rowmajor for norm and sum
int64_t ldn = problem.m();
int64_t lds = ldn;
int64_t total_elements_A_per_batch = problem.m() * problem.k();
int64_t total_elements_B_per_batch = problem.k() * problem.n();
int64_t total_elements_C_per_batch = problem.m() * problem.n();
int64_t total_elements_D_per_batch = problem.m() * problem.n();
int64_t total_elements_partial_norm_per_batch = block_num * problem.m();
int64_t total_elements_A = total_elements_A_per_batch * options.batch_count;
int64_t total_elements_B = total_elements_B_per_batch * options.batch_count;
int64_t total_elements_C = total_elements_C_per_batch * options.batch_count;
int64_t total_elements_D = total_elements_D_per_batch * options.batch_count;
int64_t total_elements_partial_norm = total_elements_partial_norm_per_batch * options.batch_count;
//
// Methods
//
Testbed(
Options const &options_
):
options(options_)
{
reference_N.reset({options.problem_size.m(), 1}, false);
}
/// Run
Disposition run() {
Disposition disposition = Disposition::kNotVerified;
//
// Initialize the workspace
//
initialize();
//
// Launch device kernel
//
hytlass::Status status = hytlass::Status::kSuccess;
status = execute_device_kernel();
if (status != hytlass::Status::kSuccess) {
std::cerr << "Device execution failed." << std::endl;
return disposition;
}
hipError_t result = hipDeviceSynchronize();
if (result != hipSuccess) {
std::cerr << "Device synchronize failed with error "
<< hipGetErrorString(result) << std::endl;
return disposition;
}
//
// Verify
//
if (options.verification_enabled) {
bool passed = verify();
if (passed) {
disposition = Disposition::kPassed;
}
else {
disposition = Disposition::kIncorrect;
}
}
//
// Profiling
//
if (options.iterations) {
profile();
}
return disposition;
}
/// Random initialization
void initialize() {
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_Softmax.reset(total_elements_D);
block_Ref.reset(total_elements_D_per_batch);
block_Norm.reset(total_elements_partial_norm);
block_Sum.reset(total_elements_partial_norm);
hytlass::reference::device::BlockFillRandomUniform(
block_A.get(), total_elements_A, options.seed, ElementA(5), ElementA(-5), 0);
hytlass::reference::device::BlockFillRandomUniform(
block_B.get(), total_elements_B, options.seed + 1, ElementB(5), ElementB(-5), 0);
hytlass::reference::device::BlockFillRandomUniform(
block_C.get(), total_elements_C, options.seed + 2, ElementC(5), ElementC(-5), 0);
hytlass::reference::device::BlockFillRandomUniform(
block_D.get(), total_elements_D, options.seed + 3, ElementD(5), ElementD(-5), 0);
hytlass::reference::device::BlockFillRandomUniform(
block_Ref.get(), total_elements_D_per_batch, options.seed + 3, ElementD(5), ElementD(-5), 0);
hytlass::reference::device::BlockFillRandomUniform(
block_Softmax.get(), total_elements_D, options.seed + 3, ElementSoftmax(5), ElementSoftmax(-5), 0);
hytlass::reference::host::TensorFill(
reference_N.host_view(),
ElementNorm()
);
}
hytlass::Status execute_device_kernel() {
hytlass::Status status = hytlass::Status::kSuccess;
//
// Setup arguments
//
GemmSoftmax::Arguments args(
options.problem_size,
options.batch_count,
{block_A.get(), lda},
{block_B.get(), ldb},
{block_C.get(), ldc},
{block_D.get(), ldc},
{
ElementCompute(options.alpha),
ElementCompute(options.beta)
},
{block_Norm.get(), ldn},
{block_Sum.get(), lds},
{block_Softmax.get(), ldc},
total_elements_A_per_batch,
total_elements_B_per_batch,
total_elements_C_per_batch,
total_elements_D_per_batch,
total_elements_partial_norm_per_batch,
total_elements_partial_norm_per_batch,
total_elements_D_per_batch
);
//
// Launch
//
GemmSoftmax gemm_softmax;
// Initialize
status = gemm_softmax.initialize(args);
if (status != hytlass::Status::kSuccess) {
return status;
}
// Run
status = gemm_softmax();
return status;
}
template<typename Element>
bool verify_tensor(std::vector<Element> const &vector_Input,
std::vector<Element> const &vector_Input_Ref) {
auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size());
float abs_tol = options.tolerance;
float rel_tol = options.tolerance;
for (int64_t i = 0; i < size; ++i) {
float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i));
float abs_diff = fabs(diff);
float abs_ref = fabs((float)vector_Input_Ref.at(i));
float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0;
if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) {
printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i)));
return false;
}
}
return true;
}
/// Verifies the device results against a host reference
bool verify() {
LayoutA layout_A(lda);
LayoutB layout_B(ldb);
LayoutC layout_C(ldc);
LayoutN Layout_N(ldn);
LayoutS Layout_S(lds);
MatrixCoord extent_A{problem.m(), problem.k()};
MatrixCoord extent_B{problem.k(), problem.n()};
MatrixCoord extent_C{problem.m(), problem.n()};
for (int batch_idx = 0; batch_idx < options.batch_count; batch_idx++) {
hytlass::TensorView<ElementA, LayoutA> view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A);
hytlass::TensorView<ElementB, LayoutB> view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B);
hytlass::TensorView<ElementC, LayoutC> view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C);
hytlass::TensorView<ElementC, LayoutC> view_Ref_device(block_Ref.get(), layout_C, extent_C);
hytlass::reference::device::GemmComplex<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute, ElementCompute
>(
problem,
options.alpha,
view_A,
hytlass::ComplexTransform::kNone,
view_B,
hytlass::ComplexTransform::kNone,
options.beta,
view_C,
view_Ref_device,
ElementCompute(0)
);
// Copy reference results to host memory for verification
std::vector<ElementD> matrix_D_Ref(layout_C.capacity(extent_C));
hytlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size());
hytlass::TensorView<ElementD, LayoutC> view_Ref(matrix_D_Ref.data(), layout_C, extent_C);
std::vector<ElementSoftmax> matrix_Softmax_Ref(layout_C.capacity(extent_C));
hytlass::TensorView<ElementSoftmax, LayoutC> view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C);
// Copy computed results to host memory
std::vector<ElementD> matrix_D(layout_C.capacity(extent_C));
hytlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size());
std::vector<ElementD> matrix_Softmax(layout_C.capacity(extent_C));
hytlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size());
// Compute the norm
for (int m = 0; m < options.problem_size.m(); ++m) {
reference_N.at({m, 0}) = view_Ref.ref().at({m, 0});
for (int n = 1; n < options.problem_size.n(); ++n) {
reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(view_Ref.ref().at({m, n})));
}
}
// Compute softmax
for (int m = 0; m < options.problem_size.m(); ++m) {
float sum = float();
for (int n = 0; n < options.problem_size.n(); ++n) {
sum += std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) );
}
float inv_sum = float(1.0f / sum);
for (int n = 0; n < options.problem_size.n(); ++n) {
view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax(
std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum
);
}
}
// Verification checks - set any of these to 'true' to skip the corresponding check.
bool verified_D = false;
bool verified_Softmax = false;
// Verify GEMM output D
if (!verified_D) {
verified_D = verify_tensor<ElementD>(matrix_D, matrix_D_Ref);
}
// Verify softmax output
if (!verified_Softmax) {
verified_Softmax = verify_tensor<ElementSoftmax>(matrix_Softmax, matrix_Softmax_Ref);
}
if (!verified_D || !verified_Softmax) {
std::cerr << "Verification check failed at batch " << batch_idx << "\n";
// Summarize which checks failed
if (!verified_D) {
std::cerr << "Verification of D tensor failed\n";
}
if (!verified_Softmax) {
std::cerr << "Verification of Softmax tensor failed\n";
}
return false;
}
}
return true;
}
/// Profiles
bool profile() {
//
// Profile
//
hytlass::Status status = hytlass::Status::kSuccess;
hipError_t result;
hipEvent_t events[2];
int const kIterations = options.iterations;
for (hipEvent_t &evt : events) {
result = hipEventCreate(&evt);
if (result != hipSuccess) {
std::cerr << "hipEventCreate failed with error " << hipGetErrorString(result) << std::endl;
return false;
}
}
result = hipEventRecord(events[0]);
if (result != hipSuccess) {
std::cerr << "hipEventRecord() failed with error " << hipGetErrorString(result) << std::endl;
return false;
}
for (int iter = 0; iter < kIterations; ++iter) {
status = execute_device_kernel();
if (status != hytlass::Status::kSuccess) {
std::cerr << "Device execution failed." << std::endl;
return false;
}
}
result = hipEventRecord(events[1]);
if (result != hipSuccess) {
std::cerr << "hipEventRecord() failed with error " << hipGetErrorString(result) << std::endl;
return false;
}
result = hipDeviceSynchronize();
if (result != hipSuccess) {
std::cerr << "hipDeviceSynchronize() failed with error " << hipGetErrorString(result) << std::endl;
return false;
}
float elapsed_ms = 0;
result = hipEventElapsedTime(&elapsed_ms, events[0], events[1]);
if (result != hipSuccess) {
std::cerr << "hipEventElapsedTime() failed with error " << hipGetErrorString(result) << std::endl;
return false;
}
for (hipEvent_t &evt : events) {
result = hipEventDestroy(evt);
if (result != hipSuccess) {
std::cerr << "hipEventDestroy() failed with error " << hipGetErrorString(result) << std::endl;
return false;
}
}
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2;
int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n();
double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9);
double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30);
double elapsed_ms_per_iter = double(elapsed_ms) / kIterations;
std::cout << " Problem: "
<< options.problem_size.m() << "-by-" << options.problem_size.n() << "-by-" << options.problem_size.k()
<< ", batch size: " << options.batch_count
<< std::endl;
std::cout << " Runtime: " << elapsed_ms_per_iter << " ms" << std::endl;
std::cout << " GFLOPs: " << gflops_per_second << std::endl;
std::cout << " Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl;
return true;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, const char **argv) {
// Options parsing
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (!options.supported()) {
return 0;
}
// Run
Testbed testbed(options);
Disposition disposition = testbed.run();
std::cout << std::endl;
switch (disposition) {
case Disposition::kPassed:
std::cout << "Passed" << std::endl;
break;
case Disposition::kIncorrect:
std::cout << "Incorrect" << std::endl;
break;
case Disposition::kNotVerified:
std::cout << "Not verified" << std::endl;
break;
}
return (disposition == Disposition::kPassed ? 0 : -1);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM kernel to support the epilogue visitor model
for customized softmax partial reduction epilogue fusion.
This source file will likely be moved to `include/hytlass/gemm/kernel/` in the future once
its usage has been stabilized. For now, it is included in this example to demonstrate
some basic output fusion options.
*/
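//
// Usage sketch (mirrors how gemm_with_softmax.h in this example composes this kernel;
// DefaultGemmKernel and Epilogue are assumed to be defined as they are there):
//
//   using GemmKernel = hytlass::gemm::kernel::GemmWithEpilogueVisitor<
//       typename DefaultGemmKernel::Mma,   // threadblock-scoped matrix multiply-accumulate
//       Epilogue,                          // EpilogueWithVisitor wrapping the softmax visitor
//       hytlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle>;
//
//   // The kernel is launched through hytlass::Kernel<GemmKernel>, with
//   // GemmKernel::Params constructed from GemmKernel::Arguments (see GemmSoftmax::run()).
//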
#pragma once
#include "hytlass/hytlass.h"
#include "hytlass/tensor_ref.h"
#include "hytlass/fast_math.h"
#include "hytlass/gemm/gemm.h"
#include "hytlass/matrix_coord.h"
#include "hytlass/complex.h"
#include "hytlass/semaphore.h"
#include "hytlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace hytlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
using ElementNorm = typename EpilogueVisitor::ElementNorm;
using ElementSum = typename EpilogueVisitor::ElementSum;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = WARP_SIZE_GPU * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value,
128 / sizeof_bits<ElementB>::value
);
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
TensorRefA ref_A;
TensorRefB ref_B;
TensorRefC ref_C;
TensorRefC ref_D;
ElementNorm *ptr_Max;
ElementSum *ptr_Sum;
int64_t batch_stride_A;
int64_t batch_stride_B;
typename EpilogueVisitor::Arguments epilogue_visitor;
//
// Methods
//
Arguments():
mode(GemmUniversalMode::kGemm),
batch_count(1)
{ }
/// constructs an arguments structure
Arguments(
GemmUniversalMode mode_,
GemmCoord problem_size_,
int batch_count_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefC ref_C_,
TensorRefC ref_D_,
ElementNorm *ptr_Max_,
ElementSum *ptr_Sum_,
int64_t batch_stride_A_,
int64_t batch_stride_B_,
typename EpilogueVisitor::Arguments epilogue_visitor_
):
mode(mode_),
problem_size(problem_size_),
batch_count(batch_count_),
ref_A(ref_A_),
ref_B(ref_B_),
ref_C(ref_C_),
ref_D(ref_D_),
ptr_Max(ptr_Max_),
ptr_Sum(ptr_Sum_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
epilogue_visitor(epilogue_visitor_)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
hytlass::gemm::GemmCoord problem_size;
hytlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void * ptr_A;
void * ptr_B;
ElementC * ptr_C;
ElementC * ptr_D;
ElementNorm * ptr_Max;
ElementSum * ptr_Sum;
int64_t batch_stride_A;
int64_t batch_stride_B;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
HYTLASS_HOST_DEVICE
Params():
swizzle_log_tile(0),
params_A(0),
params_B(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(hytlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
ptr_Max(nullptr),
ptr_Sum(nullptr),
batch_stride_A(0),
batch_stride_B(0)
{ }
Params(
Arguments const &args
):
problem_size(args.problem_size),
swizzle_log_tile(0),
params_A(args.ref_A.layout()),
params_B(args.ref_B.layout()),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(args.problem_size.k()),
ptr_A(args.ref_A.data()),
ptr_B(args.ref_B.data()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
ptr_Max(args.ptr_Max),
ptr_Sum(args.ptr_Sum),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
epilogue_visitor(args.epilogue_visitor)
{
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.batch_count);
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
struct {
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
};
public:
//
// Methods
//
HYTLASS_DEVICE
GemmWithEpilogueVisitor() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
hytlass::gemm::GemmCoord const & problem_size) {
HYTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
HYTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
HYTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
HYTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
HYTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM
HYTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
hytlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
#if SPLIT_K_ENABLED
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
}
else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
#endif
// Compute initial location in logical coordinates
hytlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
hytlass::MatrixCoord tb_offset_B{
offset_k,
threadblock_tile_offset.n() * Mma::Shape::kN
};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
ptr_A,
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B,
ptr_B,
{problem_size_k, params.problem_size.n()},
thread_idx,
tb_offset_B);
// Compute warp and lane indices within the threadblock.
int warp_idx = threadIdx.x / WARP_SIZE_GPU;
int lane_idx = threadIdx.x % WARP_SIZE_GPU;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
accumulators);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Construct the epilogue visitor
//
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.epilogue.visitor,
params.problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
params.params_C,
params.params_D,
params.ptr_C,
params.ptr_D,
params.ptr_Max,
params.ptr_Sum,
threadblock_offset,
blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
// Construct the epilogue
Epilogue epilogue(
shared_storage.epilogue.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace hytlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Device-level GEMM + Softmax operator. Composes the GEMM-with-epilogue-visitor
kernel, the softmax final-reduction kernel, and the softmax apply kernel into a
single GemmSoftmax interface.
*/
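//
// Usage sketch (mirrors the testbed earlier in this example; pointer and stride
// names are illustrative):
//
//   GemmSoftmax gemm_softmax;
//   GemmSoftmax::Arguments args(
//       problem_size, batch_count,
//       {ptr_A, lda}, {ptr_B, ldb}, {ptr_C, ldc}, {ptr_D, ldc},
//       {alpha, beta},
//       {ptr_Norm, ldn}, {ptr_Sum, lds}, {ptr_Softmax, ldc},
//       batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D,
//       batch_stride_Max, batch_stride_Sum, batch_stride_Softmax);
//   if (gemm_softmax.initialize(args) == hytlass::Status::kSuccess) {
//     gemm_softmax();   // launches GEMM+max, the final reduction, and the softmax apply
//   }
//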
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <cmath>
#include <iostream>
#include <vector>
#include <limits>
#include "hytlass/hytlass.h"
#include "hytlass/device_kernel.h"
#include "hytlass/arch/memory.h"
#include "hytlass/arch/memory_gfx928.h"
#include "hytlass/gemm/kernel/default_gemm.h"
#include "hytlass/gemm/device/default_gemm_configuration.h"
#include "hytlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "hytlass/epilogue/threadblock/epilogue_with_visitor.h"
#include "hytlass/reduction/kernel/reduce_softmax_final.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "gemm_with_epilogue_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace hytlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Kernel that applies the final softmax normalization:
//
//   Softmax[m, n] = exp(D[m, n] - N[m, 0]) * S[m, 0]
//
// where N[m, 0] holds the row-wise maximum and S[m, 0] the reciprocal of the row-wise
// sum produced by the preceding GEMM-with-visitor and final-reduction kernels.
//
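//
// Worked example for one row (hypothetical values):
//   D[m, :]           = [1.0, 3.0, 2.0]
//   N[m, 0] (max)     = 3.0
//   S[m, 0] (1 / sum) = 1 / (e^-2 + e^0 + e^-1) ~= 0.6652
//   Softmax[m, :]     = [e^-2, e^0, e^-1] * 0.6652 ~= [0.0900, 0.6652, 0.2447]
//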
template <
typename ElementD_,
typename ElementNorm_,
typename ElementSum_,
typename ElementSoft_,
typename ElementSoftmaxCompute_,
int Alignment,
typename ApplyShape_ = MatrixShape<1, 1024>
>
class ApplySoftmax {
public:
using ElementD = ElementD_;
using ElementNorm = ElementNorm_;
using ElementSum = ElementSum_;
using ElementSoft = ElementSoft_;
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
static int const kAlignment = Alignment;
using ApplyShape = ApplyShape_;
using Layout = hytlass::layout::RowMajor;
using TensorRefD = TensorRef<ElementD, Layout>;
using TensorRefN = TensorRef<ElementNorm, Layout>;
using TensorRefSum = TensorRef<ElementSum, Layout>;
using TensorRefSoft = TensorRef<ElementSoft, Layout>;
using FragmentSoftmax = Array<ElementSoftmaxCompute, kAlignment>;
//
// Arguments
//
struct Arguments {
MatrixCoord extent; ///< Extent of D and Softmax matrices
int batch_count; ///< Batch count
TensorRefD ref_D; ///< D matrix computed by GEMM+Max (input)
TensorRefN ref_N; ///< Norm tensor (input)
TensorRefSum ref_S; ///< Sum tensor (input)
TensorRefSoft ref_Soft; ///< Softmax tensor (output)
int64_t batch_stride_D; ///< Batch stride for D tensor
int64_t batch_stride_N; ///< Batch stride for N tensor
int64_t batch_stride_S; ///< Batch stride for S tensor
int64_t batch_stride_Soft; ///< Batch stride for softmax tensor
//
// Methods
//
Arguments():
batch_count(1),
batch_stride_D(0),
batch_stride_N(0),
batch_stride_S(0),
batch_stride_Soft(0)
{ }
Arguments(
MatrixCoord extent_, ///< Extent of D and Softmax matrices
int batch_count_, ///< Batch count
TensorRefD ref_D_, ///< D matrix computed by GEMM+PartialReduce
TensorRefN ref_N_, ///< Norm tensor (input)
TensorRefSum ref_S_, ///< Sum tensor (input)
TensorRefSoft ref_Soft_, ///< Softmax tensor (output)
int64_t batch_stride_D_ = 0,
int64_t batch_stride_N_ = 0,
int64_t batch_stride_S_ = 0,
int64_t batch_stride_Soft_ = 0
):
extent(extent_),
batch_count(batch_count_),
ref_D(ref_D_),
ref_N(ref_N_),
ref_S(ref_S_),
ref_Soft(ref_Soft_),
batch_stride_D(batch_stride_D_),
batch_stride_N(batch_stride_N_),
batch_stride_S(batch_stride_S_),
batch_stride_Soft(batch_stride_Soft_)
{
}
};
//
// Params struct
//
struct Params {
Arguments args;
//
// Methods
//
Params() { }
Params(Arguments const &args_): args(args_) { }
};
//
// SharedStorage
//
struct SharedStorage {
};
private:
public:
HYTLASS_DEVICE
ApplySoftmax() { }
HYTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
apply(params, shared_storage);
}
private:
/// Compute Softmax
HYTLASS_DEVICE
void apply(Params const &params, SharedStorage &shared_storage) {
int block_batch = blockIdx.z;
int block_m = blockIdx.x * ApplyShape::kRow;
int block_n = 0;
int thread_m = threadIdx.y;
int thread_n = threadIdx.x * kAlignment;
int idx_m = block_m + thread_m;
int idx_n = block_n + thread_n;
int batch_offset_norm = block_batch * params.args.batch_stride_N;
int batch_offset_sum = block_batch * params.args.batch_stride_S;
// Kill off thread if it is outside the row boundary
if (params.args.extent.row() <= idx_m) {
return;
}
//
// Setup pointers to load D again
//
using AccessTypeD = AlignedArray<ElementD, kAlignment>;
using AccessTypeSoft = AlignedArray<ElementSoft, kAlignment>;
using FragmentSoft = Array<ElementSoft, kAlignment>;
using ConvertSoftCompute = hytlass::NumericArrayConverter<ElementSoftmaxCompute, ElementD, kAlignment>;
using ConvertSoftOutput = hytlass::NumericArrayConverter<ElementSoft, ElementSoftmaxCompute, kAlignment>;
using Mul = hytlass::multiplies<FragmentSoftmax>;
using Minus = hytlass::minus<FragmentSoftmax>;
using Exp = hytlass::fast_exp_op<FragmentSoftmax>;
ConvertSoftCompute convert_soft_compute;
ConvertSoftOutput convert_soft_output;
Minus minus;
Mul mul;
Exp exponential;
using ConvertSum = hytlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
using ConvertNorm = hytlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;
ConvertSum convert_sum;
ConvertNorm convert_norm;
AccessTypeD *access_d = reinterpret_cast<AccessTypeD *>(
params.args.ref_D.data() +
params.args.batch_stride_D * block_batch +
params.args.ref_D.layout()({idx_m, idx_n}));
AccessTypeSoft *access_soft = reinterpret_cast<AccessTypeSoft *>(
params.args.ref_Soft.data() +
params.args.batch_stride_Soft * block_batch +
params.args.ref_Soft.layout()({idx_m, idx_n}));
ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum];
ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm];
//
// Loop
//
HYTLASS_PRAGMA_UNROLL
for (
int idx = 0;
idx < params.args.extent.column();
idx += ApplyShape::kColumn * kAlignment) {
if (idx_n < params.args.extent.column()) {
AccessTypeD fetch;
arch::global_load<AccessTypeD, sizeof(AccessTypeD)>(fetch, access_d, true);
FragmentSoftmax result = mul(exponential(minus(convert_soft_compute(fetch), convert_norm(norm))), convert_sum(inv_sum));
FragmentSoft soft = convert_soft_output(result);
arch::global_store<FragmentSoft, sizeof(FragmentSoft)>(soft, access_soft, true);
}
access_d += ApplyShape::kColumn;
access_soft += ApplyShape::kColumn;
idx_n += ApplyShape::kColumn * kAlignment;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
/////////////////////////////////////////////////////////////////////////////////////////////////
///
template <
typename ElementA_,
typename LayoutA_,
typename ElementB_,
typename LayoutB_,
typename ElementC_,
typename ElementCompute_,
typename OperatorClass_,
typename ArchTag_,
typename ThreadblockShape_,
typename WarpShape_,
typename InstructionShape_,
typename EpilogueFunctorOp_,
int kStages_,
typename ApplyShape_ = MatrixShape<1, 1024>,
int AlignmentA_ = 128 / hytlass::sizeof_bits<ElementA_>::value,
int AlignmentB_ = 128 / hytlass::sizeof_bits<ElementB_>::value,
int AlignmentSoftmax_ = 128 / hytlass::sizeof_bits<ElementC_>::value,
typename ElementNorm_ = float,
typename ElementSum_ = float,
typename ElementSoftmax_ = ElementC_
>
class GemmSoftmax {
public:
///////////////////////////////////////////////////////////////////////////////////////////////
//
// Type definitions
//
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementC = ElementC_;
using ElementCompute = ElementCompute_;
using ElementSum = ElementSum_;
using ElementSoft = ElementSoftmax_;
using ElementSoftmaxCompute = float;
using LayoutA = LayoutA_;
using LayoutB = LayoutB_;
using EpilogueFunctorOp = EpilogueFunctorOp_;
using ElementNorm = ElementNorm_;
using ApplyShape = ApplyShape_;
// These are mandatory layouts.
using LayoutC = hytlass::layout::RowMajor;
using LayoutN = hytlass::layout::RowMajor;
using LayoutS = hytlass::layout::RowMajor;
using LayoutSoft = hytlass::layout::RowMajor;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using TensorRefC = TensorRef<ElementC, LayoutC>;
using TensorRefN = TensorRef<ElementNorm, LayoutN>;
using TensorRefSum = TensorRef<ElementSum, LayoutS>;
using TensorRefSoft = TensorRef<ElementSoft, LayoutSoft>;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
static int const kStages = kStages_;
static int const AlignmentA = AlignmentA_;
static int const AlignmentB = AlignmentB_;
static int const AlignmentSoftmax = AlignmentSoftmax_;
using ThreadblockSwizzle = hytlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;
///////////////////////////////////////////////////////////////////////////////////////////////
// basic GEMM kernel
using DefaultGemmKernel = typename hytlass::gemm::kernel::DefaultGemm<
ElementA,
LayoutA,
AlignmentA,
ElementB,
LayoutB,
AlignmentB,
ElementC,
LayoutC,
ElementCompute,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueFunctorOp,
ThreadblockSwizzle,
kStages,
true,
typename hytlass::gemm::device::DefaultGemmConfiguration<
OperatorClass, ArchTag, ElementA, ElementB, ElementC, ElementCompute>::Operator,
hytlass::gemm::SharedMemoryClearOption::kNone
>::GemmKernel;
///////////////////////////////////////////////////////////////////////////////////////////////
// Epilogue visitor
using EpilogueVisitor = typename hytlass::epilogue::threadblock::EpilogueVisitorSoftmax<
ThreadblockShape,
DefaultGemmKernel::kThreadCount,
typename DefaultGemmKernel::Epilogue::OutputTileIterator,
ElementCompute,
ElementNorm,
ElementSum,
ElementSoftmaxCompute,
EpilogueFunctorOp
>;
/// Epilogue
using Epilogue = typename hytlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
EpilogueVisitor,
typename DefaultGemmKernel::Epilogue
>::Epilogue;
// GEMM
using GemmKernel = gemm::kernel::GemmWithEpilogueVisitor<
typename DefaultGemmKernel::Mma,
Epilogue,
ThreadblockSwizzle
>;
// Softmax kernel
using SoftmaxApplyKernel = kernel::ApplySoftmax<
ElementC,
ElementNorm,
ElementSum,
ElementSoft,
ElementSoftmaxCompute,
AlignmentSoftmax,
ApplyShape
>;
using ApplyFinalReductionKernel = hytlass::reduction::kernel::ApplySoftmaxFinalReduction<
ElementNorm,
ElementSum,
ElementSoftmaxCompute,
ThreadblockShape
>;
public:
/// Arguments class
struct Arguments {
typename GemmKernel::Arguments gemm;
typename SoftmaxApplyKernel::Arguments softmax;
typename ApplyFinalReductionKernel::Arguments reduction;
hytlass::gemm::GemmCoord extend;
//
// Methods
//
Arguments() { }
Arguments(
hytlass::gemm::GemmCoord problem_size,
int32_t batch_count_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefC ref_C_,
TensorRefC ref_D_,
typename EpilogueFunctorOp::Params linear_scaling,
TensorRefN ref_N_,
TensorRefSum ref_S_,
TensorRefSoft ref_Softmax_,
int64_t batch_stride_A_ = 0,
int64_t batch_stride_B_ = 0,
int64_t batch_stride_C_ = 0,
int64_t batch_stride_D_ = 0,
int64_t batch_stride_Max_ = 0,
int64_t batch_stride_Sum_ = 0,
int64_t batch_stride_Softmax_ = 0
):
gemm(
hytlass::gemm::GemmUniversalMode::kBatched,
problem_size,
batch_count_,
ref_A_,
ref_B_,
ref_C_,
ref_D_,
ref_N_.data(),
ref_S_.data(),
batch_stride_A_,
batch_stride_B_,
typename EpilogueVisitor::Arguments(
linear_scaling,
batch_stride_C_,
batch_stride_D_,
batch_stride_Max_,
batch_stride_Sum_
)
),
reduction(
problem_size,
ref_N_.data(),
ref_S_.data(),
batch_stride_Max_,
batch_stride_Sum_
),
softmax(
MatrixCoord(problem_size.m(), problem_size.n()),
batch_count_,
ref_D_,
ref_N_,
ref_S_,
ref_Softmax_,
batch_stride_D_,
batch_stride_Max_,
batch_stride_Sum_,
batch_stride_Softmax_
),
extend(problem_size)
{
}
};
struct Params {
typename GemmKernel::Params gemm;
typename SoftmaxApplyKernel::Params softmax;
typename ApplyFinalReductionKernel::Params reduction;
MatrixCoord extend;
//
// Methods
//
Params() { }
Params(Arguments const &args):
gemm(args.gemm),
reduction(args.reduction),
softmax(args.softmax),
extend(MatrixCoord(args.extend.m(), args.extend.n()))
{
}
};
public:
// Gemm
//
// Methods
//
private:
Params params_;
public:
/// Ctor
GemmSoftmax() {
}
/// Initialize
Status initialize(Arguments const &args) {
params_ = Params(args);
return hytlass::Status::kSuccess;
}
/// Run
Status run(hipStream_t stream) {
//
// Launch the GEMM + max kernel
//
dim3 gemm_grid = ThreadblockSwizzle().get_grid_shape(params_.gemm.grid_tiled_shape);
dim3 gemm_block(GemmKernel::kThreadCount, 1, 1);
int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage));
hipError_t result;
if (gemm_smem_size >= (48 << 10)) {
result = hipFuncSetAttribute((const void*)hytlass::Kernel<GemmKernel>,
hipFuncAttributeMaxDynamicSharedMemorySize,
gemm_smem_size);
if (result != hipSuccess) {
return Status::kErrorInternal;
}
}
hytlass::Kernel<GemmKernel><<<gemm_grid, gemm_block, gemm_smem_size, stream>>>(params_.gemm);
result = hipGetLastError();
if (result != hipSuccess) {
return hytlass::Status::kErrorInternal;
}
//
// Launch the ApplyFinalReductionKernel
//
int thread_per_block = 128;
int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
if (block_per_row < 4) {
thread_per_block = 64;
block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block;
}
dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count);
dim3 final_reduction_block(thread_per_block);
Kernel<ApplyFinalReductionKernel><<<
final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream
>>>(params_.reduction);
result = hipGetLastError();
if (result != hipSuccess) {
return hytlass::Status::kErrorInternal;
}
//
// Launch the SoftmaxApplyKernel
//
dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow);
int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow;
int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment;
dim3 apply_grid(
(params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows,
(params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns,
params_.softmax.args.batch_count);
Kernel<SoftmaxApplyKernel><<<
apply_grid, apply_block, sizeof(typename SoftmaxApplyKernel::SharedStorage), stream
>>>(params_.softmax);
result = hipGetLastError();
if (result != hipSuccess) {
return hytlass::Status::kErrorInternal;
}
return hytlass::Status::kSuccess;
}
/// Function call operator
Status operator()(hipStream_t stream = nullptr) {
return run(stream);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace hytlass
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
ell_block_sparse_gemm
ell_block_sparse_gemm.cu
)
hytlass_example_add_executable(
ell_block_sparse_gemm_bias_relu
ell_block_sparse_gemm_bias_relu.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Block-Ell sparse gemm example.
This example performs a Sparse-matrix dense-matrix multiplication (SpMM) operation.
Matrix A is stored in the Blocked-Ellpack (Blocked-ELL) storage format.
Details about the Blocked-Ellpack (Blocked-ELL) storage format can be found here:
https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-spmat-create-blockedell
Whereas matrix B is a dense matrix.
The Blocked-Ellpack (Blocked-ELL) storage format comprises two matrices.
The first is a packed matrix (the ellValue matrix) that stores non-zero values in consecutive
blocks, represented by tensor_a in this example. The second is a matrix of indices (the
ellColInd matrix), represented by tensor_ell_idx in this example, that holds the column indices
of the corresponding non-zero blocks. All block rows must contain the same number of blocks.
ellColInd may contain -1 to indicate an empty block. Both matrices store elements in
row-major order.
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format
for this example:
a_rows - Rows in the sparse matrix.
a_cols - Columns in the sparse matrix.
a_ell_blocksize - Size of the ELL-Blocks.
a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns)
tensor_ell_idx - Blocked-ELL Column indices (ellColInd), whose size is
(a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize)
tensor_b - Input dense matrix whose size is (a_cols * n)
tensor_c/tensor_d - Output dense matrix whose size is (a_rows * n)
{a_rows, n, a_cols} - Problem size
*/
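//
// Illustrative sketch (hypothetical values, not used by this example): Blocked-ELL
// storage of a 4x4 matrix with a_ell_blocksize = 2 and a_ell_num_columns = 2.
//
//   Dense view:        ellValue (a_rows x a_ell_num_columns):   ellColInd (2 x 1):
//   | 1 2 0 0 |          | 1 2 |                                  | 0 |
//   | 3 4 0 0 |          | 3 4 |                                  | 1 |
//   | 0 0 5 6 |          | 5 6 |
//   | 0 0 7 8 |          | 7 8 |
//
// Row block 0 keeps its non-zero block from block-column 0; row block 1 keeps its
// non-zero block from block-column 1. Unused slots would be marked with -1.
// Equivalent host-side arrays (row-major):
//   float ell_value[4][2]   = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
//   int   ell_col_ind[2][1] = {{0}, {1}};
//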
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <unordered_map>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/gemm.h"
#include "hytlass/gemm/kernel/gemm_grouped.h"
#include "hytlass/gemm/kernel/default_gemm_grouped.h"
#include "hytlass/gemm/device/ell_gemm.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/distribution.h"
#include "hytlass/util/device_memory.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/reference/host/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/tensor_norm.h"
#include "hytlass/util/host_uncompress.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
bool reference_check;
int iterations;
int hip_streams;
int a_rows, n, a_cols;
int a_ell_num_columns;
int a_ell_blocksize;
int a_base;
float alpha;
float beta;
//
// Methods
//
Options():
help(false),
reference_check(true),
iterations(20),
hip_streams(0),
a_rows(1024),
n(1024),
a_cols(1024),
a_ell_num_columns(512),
a_ell_blocksize(16),
a_base(0),
alpha(1),
beta()
{ }
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
cmd.get_cmd_line_argument("beta", beta, 0.0f);
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("streams", hip_streams, 0);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
cmd.get_cmd_line_argument("a_rows", a_rows, 1024);
cmd.get_cmd_line_argument("n", n, 1024);
cmd.get_cmd_line_argument("a_cols", a_cols, 1024);
cmd.get_cmd_line_argument("a_ell_num_columns", a_ell_num_columns, 512);
cmd.get_cmd_line_argument("a_ell_blocksize", a_ell_blocksize, 16);
cmd.get_cmd_line_argument("a_base", a_base, 0);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "43_ell_block_sparse_gemm\n\n"
<< " This example profiles the performance of an ELL block sparse GEMM kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --a_rows=<int> Sets the number of the rows of the sparse matrix.\n"
<< " --n=<int> Sets the N dimension.\n"
<< " --a_cols=<int> Sets the number of columns of the sparse matrix.\n"
<< " --a_ell_num_columns=<int> Sets the actual number of columns of the Blocked-Ellpack format.\n"
<< " --a_ell_blocksize=<int> Sets the size of the ELL-Block.\n"
<< " --a_base=<int> Sets the base index.\n"
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
<< " --beta=<f32> Epilogue scalar beta (real part)\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --reference-check=<bool> If true, performs reference check.\n";
out << "\n\nExamples:\n\n"
<< "# Runs a 1024x1024x1024 ELL block sparse GEMM with 16x16 block size and actual 512 non-zero columns in A operand\n"
<< "$ ./examples/43_ell_block_sparse_gemm/43_ell_block_sparse_gemm --a_rows=1024 --n=1024 --a_cols=1024 --a_ell_num_columns=512 --a_ell_blocksize=16\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = (int64_t)a_rows * (int64_t)a_cols * (int64_t)n;
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
class Testbed {
public:
//
// Type definitions
//
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
using ElementAccumulator = typename Gemm::ElementAccumulator;
using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using MatrixCoord = typename LayoutC::TensorCoord;
private:
//
// Data members
//
Options options;
/// Initialization
hytlass::Distribution::Kind init_A;
hytlass::Distribution::Kind init_B;
hytlass::Distribution::Kind init_C;
hytlass::Distribution::Kind init_ELL;
uint32_t seed;
hytlass::HostTensor<ElementA, LayoutA> tensor_a;
hytlass::HostTensor<ElementB, LayoutB> tensor_b;
hytlass::HostTensor<ElementC, LayoutC> tensor_c;
hytlass::HostTensor<ElementC, LayoutC> tensor_d;
hytlass::HostTensor<ElementA, LayoutA> tensor_a_uncompressed;
hytlass::HostTensor<ElementC, LayoutC> reference_d;
hytlass::HostTensor<int32_t, LayoutA> tensor_ell_idx;
public:
//
// Methods
//
Testbed(
Options const &options_,
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_ELL_ = hytlass::Distribution::Uniform,
uint32_t seed_ = 3080
):
options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), init_ELL(init_ELL_), seed(seed_) { }
private:
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
void initialize_tensor_(
hytlass::TensorView<Element, Layout> view,
hytlass::Distribution::Kind dist_kind,
uint32_t seed) {
if (dist_kind == hytlass::Distribution::Uniform) {
Element scope_max, scope_min;
int bits_input = hytlass::sizeof_bits<Element>::value;
int bits_output = hytlass::sizeof_bits<typename Gemm::ElementC>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
if (hytlass::sizeof_bits<ElementAccumulator>::value <= 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
} else {
scope_max = 8;
scope_min = -8;
}
hytlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == hytlass::Distribution::Gaussian) {
hytlass::reference::host::TensorFillRandomGaussian(
view, seed, Element(), Element(0.5f));
}
else if (dist_kind == hytlass::Distribution::Sequential) {
// Fill with increasing elements
hytlass::reference::host::BlockFillSequential(
view.data(), view.capacity(), Element(1), Element());
} else {
// Fill with all 1s
hytlass::reference::host::BlockFillSequential(
view.data(), view.capacity(), Element(), Element(1));
}
}
/// Initializes data structures
void initialize_() {
tensor_a.resize(hytlass::make_Coord(options.a_rows, options.a_ell_num_columns));
tensor_b.resize(hytlass::make_Coord(options.a_cols, options.n));
tensor_c.resize(hytlass::make_Coord(options.a_rows, options.n));
tensor_d.resize(hytlass::make_Coord(options.a_rows, options.n));
tensor_a_uncompressed.resize(hytlass::make_Coord(options.a_rows, options.a_cols));
reference_d.resize(hytlass::make_Coord(options.a_rows, options.n));
tensor_ell_idx.resize(hytlass::make_Coord(options.a_rows / options.a_ell_blocksize,
options.a_ell_num_columns / options.a_ell_blocksize));
//
// Initialize the problems of the workspace
//
initialize_tensor_(tensor_a.host_view(), init_A, seed * 2021);
initialize_tensor_(tensor_b.host_view(), init_B, seed * 2022);
initialize_tensor_(tensor_c.host_view(), init_C, seed * 2023);
if (init_ELL == hytlass::Distribution::Uniform) {
hytlass::reference::host::TensorFillRandomEllIdx(
tensor_ell_idx.host_view(), seed,
options.a_rows / options.a_ell_blocksize,
options.a_ell_num_columns / options.a_ell_blocksize,
options.a_cols / options.a_ell_blocksize);
} else {
for(int i = 0; i < options.a_rows / options.a_ell_blocksize; ++i) {
for(int j = 0; j < options.a_ell_num_columns / options.a_ell_blocksize; ++j) {
tensor_ell_idx.at({i, j}) = j+3;
}
}
}
hytlass::reference::host::TensorCopy(reference_d.host_view(), tensor_c.host_view());
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ell_idx.sync_device();
}
/// Verifies the result against a host reference GEMM
bool verify_() {
bool passed = true;
tensor_d.sync_host();
hytlass::uncompress_ell_block_sparse(
tensor_a_uncompressed.host_ref(),
tensor_a.host_ref(),
tensor_ell_idx.host_ref(),
options.a_rows,
options.a_cols,
options.a_ell_num_columns,
options.a_ell_blocksize
);
hytlass::reference::host::Gemm<
typename Gemm::ElementA, typename Gemm::LayoutA,
typename Gemm::ElementB, typename Gemm::LayoutB,
typename Gemm::ElementC, typename Gemm::LayoutC,
ElementCompute,
ElementAccumulator, typename Gemm::Operator>
reference_gemm;
reference_gemm(
{options.a_rows, options.n, options.a_cols},
options.alpha,
tensor_a_uncompressed.host_ref(),
tensor_b.host_ref(),
options.beta,
reference_d.host_ref(),
ElementAccumulator(0)
);
// Reference check
passed = hytlass::reference::host::TensorEquals(tensor_d.host_view(), reference_d.host_view());
if (!passed) {
std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl;
std::stringstream fname;
fname << "error_43_ell_block_sparse_gemm"
<< "mnk_"
<< options.a_rows << "x"
<< options.n << "x"
<< options.a_cols << "_"
<< options.a_ell_num_columns << "_"
<< options.a_ell_blocksize << ".txt";
std::cout << fname.str() << std::endl;
std::ofstream results(fname.str());
results
<< "alpha: " << ElementCompute(options.alpha) << "\n"
<< "beta: " << ElementCompute(options.beta) << "\n"
<< "block size: " << options.a_ell_blocksize << "\n"
<< "\nA:\n" << tensor_a.host_view() << "\n"
<< "\nA Ell Index:\n" << tensor_ell_idx.host_view() << "\n"
<< "\nB:\n" << tensor_b.host_view() << "\n"
<< "\nC:\n" << tensor_c.host_view() << "\n"
<< "\nD reference:\n" << reference_d.host_view() << "\n"
<< "\nD computed:\n" << tensor_d.host_view() << "\n";
return passed;
}
return passed;
}
public:
/// Returns true if the active device has sufficient resources (e.g. shared memory)
/// to run the kernel; otherwise returns false.
bool sufficient() const {
//
// Determine SMEM requirements and waive if not satisfied
//
size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage);
hipDeviceProp_t properties;
int device_idx;
hipError_t result = hipGetDevice(&device_idx);
if (result != hipSuccess) {
throw std::runtime_error("hipGetDevice() API call failed.");
}
result = hipGetDeviceProperties(&properties, device_idx);
if (result != hipSuccess) {
throw std::runtime_error("hipGetDeviceProperties() failed");
}
if (properties.sharedMemPerBlock < smem_size) {
return false;
}
return true;
}
/// Executes a BlockedEll SpMM kernel and measures runtime.
Result profile() {
Result result;
// Early exit
if (!sufficient()) {
std::cout << "Active HIP device lacks hardware resources to run HYTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
result.passed = false;
// Initialize the problem
initialize_();
// Configure the GEMM arguments
typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta);
// Configure GEMM arguments
typename Gemm::Arguments args(
{options.a_rows, options.n, options.a_cols},
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_d.device_ref(),
tensor_ell_idx.device_data(),
options.a_ell_num_columns,
options.a_ell_blocksize,
options.a_base,
epilogue_op
);
// Initialize the GEMM object
Gemm gemm{};
result.status = gemm.initialize(args);
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to initialize HYTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
// Run the BlockedEll SpMM object
result.status = gemm.run();
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to run HYTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
// Wait for completion
result.error = hipDeviceSynchronize();
if (result.error != hipSuccess) {
std::cerr << "Kernel execution error: " << hipGetErrorString(result.error);
return result;
}
//
// Verify correctness
//
result.passed = true;
if (options.reference_check) {
result.passed = verify_();
}
//
// Warm-up run
//
result.status = gemm.run();
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to run HYTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
//
// Construct events
//
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of GEMM operations
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
gemm();
}
//
// Stop profiling loop
//
// Record an event when the GEMM operations have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
//
// Cleanup
//
for (auto event : events) {
(void)hipEventDestroy(event);
}
std::cout << std::endl;
std::cout << "ELL Block Sparse GEMM (HYTLASS):\n"
<< "====================================================" << std::endl;
std::cout << std::endl;
std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " " << " GFLOPs: " << result.gflops << std::endl;
return result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
//
hipDeviceProp_t props;
hipError_t error = hipGetDeviceProperties(&props, 0);
if (error != hipSuccess) {
std::cerr << "hipGetDeviceProperties() returned an error: " << hipGetErrorString(error) << std::endl;
return -1;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Define the BlockedEll type
//
using ElementA = hytlass::half_t;
using ElementB = hytlass::half_t;
using ElementOutput = hytlass::half_t;
using ElementAccumulator = float;
using LayoutA = hytlass::layout::RowMajor;
using LayoutB = hytlass::layout::RowMajor;
using LayoutC = hytlass::layout::RowMajor;
constexpr int32_t kAlignmentA = 128 / hytlass::sizeof_bits<ElementA>::value;
constexpr int32_t kAlignmentB = 128 / hytlass::sizeof_bits<ElementB>::value;
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>;
constexpr int32_t kStages = 2;
using Gemm = typename hytlass::gemm::device::EllGemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementOutput,
LayoutC,
ElementAccumulator,
hytlass::arch::OpClassTensorOp,
hytlass::arch::Gfx928,
ThreadblockShape,
WarpShape,
InstructionShape,
hytlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / hytlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
kStages, kAlignmentA, kAlignmentB>;
//
// Profile it
//
Testbed<Gemm> testbed(options);
if (!testbed.sufficient()) {
std::cout << "The active HIP device lacks sufficient hardware resources to execute this kernel.\n";
return 0;
}
Result result = testbed.profile();
if (!result.passed) {
std::cout << "Profiling HYTLASS ELL block sparse GEMM has failed.\n";
std::cout << "\nFailed\n";
return -1;
}
std::cout << "\nPassed\n";
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Blocked-ELL block sparse GEMM example with bias and ReLU.
This example performs a sparse-matrix dense-matrix multiplication (SpMM) operation.
Matrix A is stored in the Blocked-Ellpack (Blocked-ELL) storage format.
Details about the Blocked-Ellpack (Blocked-ELL) storage format can be found here:
https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-spmat-create-blockedell
Matrix B is a dense matrix.
The Blocked-Ellpack (Blocked-ELL) storage format consists of two matrices.
The first is a packed matrix (the ellValue matrix) that stores the non-zero values in consecutive
blocks, represented by tensor_a in this example. The second is a matrix of indices (the ellColInd
matrix), represented by tensor_ell_idx in this example, that stores the column indices of the
corresponding non-zero blocks. All rows in the matrices must have the same number of blocks.
ellColInd can contain -1 values to indicate empty blocks. Both matrices store elements in
row-major order.
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format
for this example:
a_rows - Rows in the sparse matrix.
a_cols - Columns in the sparse matrix.
a_ell_blocksize - Size of the ELL-Blocks.
a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns)
tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns)
tensor_ell_idx - Blocked-ELL Column indices (ellColInd), whose size is
(a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize)
tensor_b - Input dense matrix whose size is (a_cols * n)
tensor_c/tensor_d - Output dense matrix whose size is (a_rows * n)
{a_rows, n, a_cols} - Problem size
*/
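/*
  Illustrative sketch (hypothetical values, not produced by this example): with a_rows = 4,
  a_cols = 8, a_ell_blocksize = 2, and a_ell_num_columns = 4, each block-row stores two 2x2
  non-zero blocks. A possible ellColInd (tensor_ell_idx) is

      [ 0  3 ]      block-row 0 keeps block-columns 0 and 3
      [ 1 -1 ]      block-row 1 keeps block-column 1; -1 marks an empty block

  and tensor_a is the 4x4 ellValue matrix holding the packed non-zero blocks in the same order.
  The computed problem is then D = activation(alpha * uncompress(A) * B + bias) with problem size
  {a_rows, n, a_cols} = {4, n, 8}.
*/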
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <unordered_map>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/gemm.h"
#include "hytlass/gemm/kernel/gemm_grouped.h"
#include "hytlass/gemm/kernel/default_gemm_grouped.h"
#include "hytlass/gemm/device/ell_gemm.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/distribution.h"
#include "hytlass/util/device_memory.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/reference/host/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/tensor_norm.h"
#include "hytlass/util/host_uncompress.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
// Note this example only works for RowMajor output
// ell_block_sparse_gemm + bias + relu
/// Result structure
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
bool reference_check;
int iterations;
int hip_streams;
int a_rows, n, a_cols;
int a_ell_num_columns;
int a_ell_blocksize;
int a_base;
float alpha;
//
// Methods
//
Options():
help(false),
reference_check(true),
iterations(20),
hip_streams(0),
a_rows(1024),
n(1024),
a_cols(1024),
a_ell_num_columns(512),
a_ell_blocksize(16),
a_base(0),
alpha(1)
{ }
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("streams", hip_streams, 0);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
cmd.get_cmd_line_argument("a_rows", a_rows, 1024);
cmd.get_cmd_line_argument("n", n, 1024);
cmd.get_cmd_line_argument("a_cols", a_cols, 1024);
cmd.get_cmd_line_argument("a_ell_num_columns", a_ell_num_columns, 512);
cmd.get_cmd_line_argument("a_ell_blocksize", a_ell_blocksize, 16);
cmd.get_cmd_line_argument("a_base", a_base, 0);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "43_ell_block_sparse_gemm\n\n"
<< " This example profiles the performance of a ELL block sparse GEMM kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --a_rows=<int> Sets the number of the rows of the sparse matrix.\n"
<< " --n=<int> Sets the N dimension.\n"
<< " --a_cols=<int> Sets the number of columns of the sparse matrix.\n"
<< " --a_ell_num_columns=<int> Sets the actual number of columns of the Blocked-Ellpack format.\n"
<< " --a_ell_blocksize=<int> Sets the size of the ELL-Block.\n"
<< " --a_base=<int> Sets the base index.\n"
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --reference-check=<bool> If true, performs reference check.\n";
out << "\n\nExamples:\n\n"
<< "# Runs a 1024x1024x1024 ELL block sparse GEMM with 16x16 block size and actual 512 non-zero columns in A operand\n"
<< "$ ./examples/43_ell_block_sparse_gemm/43_ell_block_sparse_gemm --a_rows=1024 --n=1024 --a_cols=1024 --a_ell_num_columns=512 --a_ell_blocksize=16\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = (int64_t)a_rows * (int64_t)a_cols * (int64_t)n;
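// Note: this counts FLOPs for the equivalent dense GEMM (a_rows x a_cols x n),
// not for the smaller compressed a_rows x a_ell_num_columns x n computation.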
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
class Testbed {
public:
//
// Type definitions
//
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
using ElementAccumulator = typename Gemm::ElementAccumulator;
using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using MatrixCoord = typename LayoutC::TensorCoord;
private:
//
// Data members
//
Options options;
/// Initialization
hytlass::Distribution::Kind init_A;
hytlass::Distribution::Kind init_B;
hytlass::Distribution::Kind init_C;
hytlass::Distribution::Kind init_ELL;
uint32_t seed;
hytlass::HostTensor<ElementA, LayoutA> tensor_a;
hytlass::HostTensor<ElementB, LayoutB> tensor_b;
hytlass::HostTensor<ElementC, LayoutC> tensor_c_bias;
hytlass::HostTensor<ElementC, LayoutC> tensor_d;
hytlass::HostTensor<ElementA, LayoutA> tensor_a_uncompressed;
hytlass::HostTensor<ElementC, LayoutC> reference_d;
hytlass::HostTensor<int32_t, LayoutA> tensor_ell_idx;
public:
//
// Methods
//
Testbed(
Options const &options_,
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_ELL_ = hytlass::Distribution::Uniform,
uint32_t seed_ = 3080
):
options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), init_ELL(init_ELL_), seed(seed_) { }
private:
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
void initialize_tensor_(
hytlass::TensorView<Element, Layout> view,
hytlass::Distribution::Kind dist_kind,
uint32_t seed) {
if (dist_kind == hytlass::Distribution::Uniform) {
Element scope_max, scope_min;
int bits_input = hytlass::sizeof_bits<Element>::value;
int bits_output = hytlass::sizeof_bits<typename Gemm::ElementC>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
if (hytlass::sizeof_bits<ElementAccumulator>::value <= 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
} else {
scope_max = 8;
scope_min = -8;
}
hytlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == hytlass::Distribution::Gaussian) {
hytlass::reference::host::TensorFillRandomGaussian(
view, seed, Element(), Element(0.5f));
}
else if (dist_kind == hytlass::Distribution::Sequential) {
// Fill with increasing elements
hytlass::reference::host::BlockFillSequential(
view.data(), view.capacity(), Element(1), Element());
} else {
// Fill with all 1s
hytlass::reference::host::BlockFillSequential(
view.data(), view.capacity(), Element(), Element(1));
}
}
/// Initializes data structures
void initialize_() {
tensor_a.resize(hytlass::make_Coord(options.a_rows, options.a_ell_num_columns));
tensor_b.resize(hytlass::make_Coord(options.a_cols, options.n));
tensor_c_bias.resize(hytlass::make_Coord(1, options.n)); // length-n bias row vector, broadcast across rows
tensor_d.resize(hytlass::make_Coord(options.a_rows, options.n));
tensor_a_uncompressed.resize(hytlass::make_Coord(options.a_rows, options.a_cols));
reference_d.resize(hytlass::make_Coord(options.a_rows, options.n));
tensor_ell_idx.resize(hytlass::make_Coord(options.a_rows / options.a_ell_blocksize,
options.a_ell_num_columns / options.a_ell_blocksize));
//
// Initialize the problems of the workspace
//
initialize_tensor_(tensor_a.host_view(), init_A, seed * 2021);
initialize_tensor_(tensor_b.host_view(), init_B, seed * 2022);
initialize_tensor_(tensor_c_bias.host_view(), init_C, seed * 2023);
if (init_ELL == hytlass::Distribution::Uniform) {
hytlass::reference::host::TensorFillRandomEllIdx(
tensor_ell_idx.host_view(), seed,
options.a_rows / options.a_ell_blocksize,
options.a_ell_num_columns / options.a_ell_blocksize,
options.a_cols / options.a_ell_blocksize);
} else {
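// Deterministic fill (illustrative pattern): every block-row references consecutive block-columns starting at block index 3.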
for(int i = 0; i < options.a_rows / options.a_ell_blocksize; ++i) {
for(int j = 0; j < options.a_ell_num_columns / options.a_ell_blocksize; ++j) {
tensor_ell_idx.at({i, j}) = j+3;
}
}
}
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c_bias.sync_device();
tensor_d.sync_device();
tensor_ell_idx.sync_device();
}
/// Verifies the device result against a host reference GEMM
bool verify_() {
bool passed = true;
tensor_d.sync_host();
hytlass::uncompress_ell_block_sparse(
tensor_a_uncompressed.host_ref(),
tensor_a.host_ref(),
tensor_ell_idx.host_ref(),
options.a_rows,
options.a_cols,
options.a_ell_num_columns,
options.a_ell_blocksize
);
hytlass::reference::host::Gemm<
typename Gemm::ElementA, typename Gemm::LayoutA,
typename Gemm::ElementB, typename Gemm::LayoutB,
typename Gemm::ElementC, typename Gemm::LayoutC,
ElementCompute,
ElementAccumulator, typename Gemm::Operator>
reference_gemm;
reference_gemm(
{options.a_rows, options.n, options.a_cols},
options.alpha,
tensor_a_uncompressed.host_ref(),
tensor_b.host_ref(),
0,
reference_d.host_ref(),
ElementAccumulator(0)
);
// Compute bias + relu in host code
for (int i = 0; i < options.a_rows; ++i) {
for (int j = 0; j < options.n; ++j) {
reference_d.at({i, j}) = std::max(
ElementC(0),
ElementC(reference_d.at({i, j}) + tensor_c_bias.at({0, j}))
);
}
}
// Reference check
passed = hytlass::reference::host::TensorEquals(tensor_d.host_view(), reference_d.host_view());
if (!passed) {
std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl;
std::stringstream fname;
fname << "error_43_ell_block_sparse_gemm"
<< "mnk_"
<< options.a_rows << "x"
<< options.n << "x"
<< options.a_cols << "_"
<< options.a_ell_num_columns << "_"
<< options.a_ell_blocksize << ".txt";
std::cout << fname.str() << std::endl;
std::ofstream results(fname.str());
results
<< "alpha: " << ElementCompute(options.alpha) << "\n"
<< "block size: " << options.a_ell_blocksize << "\n"
<< "\nA:\n" << tensor_a.host_view() << "\n"
<< "\nA Ell Index:\n" << tensor_ell_idx.host_view() << "\n"
<< "\nB:\n" << tensor_b.host_view() << "\n"
<< "\nC:\n" << tensor_c_bias.host_view() << "\n"
<< "\nD reference:\n" << reference_d.host_view() << "\n"
<< "\nD computed:\n" << tensor_d.host_view() << "\n";
return passed;
}
return passed;
}
public:
/// Returns true if the active device has sufficient resources (e.g., shared memory capacity)
/// to run the kernel; otherwise returns false.
bool sufficient() const {
//
// Determine SMEM requirements and waive if not satisfied
//
size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage);
hipDeviceProp_t properties;
int device_idx;
hipError_t result = hipGetDevice(&device_idx);
if (result != hipSuccess) {
throw std::runtime_error("hipGetDevice() API call failed.");
}
result = hipGetDeviceProperties(&properties, device_idx);
if (result != hipSuccess) {
throw std::runtime_error("hipGetDeviceProperties() failed");
}
if (properties.sharedMemPerBlock < smem_size) {
return false;
}
return true;
}
/// Executes a BlockedEll SpMM kernel and measures runtime.
Result profile() {
Result result;
// Early exit
if (!sufficient()) {
std::cout << "Active hip device lacks hardware resources to run HYTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
result.passed = false;
// Initialize the problem
initialize_();
// Configure the GEMM arguments
typename EpilogueOutputOp::Params epilogue_op(options.alpha);
// Configure GEMM arguments
typename Gemm::Arguments args(
{options.a_rows, options.n, options.a_cols},
tensor_a.device_ref(),
tensor_b.device_ref(),
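// Source operand C: the bias row vector is passed with leading stride 0 so it is broadcast across all rows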
{tensor_c_bias.device_data(), 0},
tensor_d.device_ref(),
tensor_ell_idx.device_data(),
options.a_ell_num_columns,
options.a_ell_blocksize,
options.a_base,
epilogue_op
);
// Initialize the GEMM object
Gemm gemm{};
result.status = gemm.initialize(args);
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to initialize HYTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
// Run the BlockedEll SpMM object
result.status = gemm.run();
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to run HYTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
// Wait for completion
result.error = hipDeviceSynchronize();
if (result.error != hipSuccess) {
std::cerr << "Kernel execution error: " << hipGetErrorString(result.error);
return result;
}
//
// Verify correctness
//
result.passed = true;
if (options.reference_check) {
result.passed = verify_();
}
//
// Warm-up run
//
result.status = gemm.run();
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to run HYTLASS BlockedEll SpMM kernel." << std::endl;
return result;
}
//
// Construct events
//
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of GEMM operations
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
//
// Run profiling loop
//
for (int iter = 0; iter < options.iterations; ++iter) {
gemm();
}
//
// Stop profiling loop
//
// Record an event when the GEMM operations have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
//
// Cleanup
//
for (auto event : events) {
(void)hipEventDestroy(event);
}
std::cout << std::endl;
std::cout << "ELL Block Sparse GEMM (HYTLASS):\n"
<< "====================================================" << std::endl;
std::cout << std::endl;
std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " " << " GFLOPs: " << result.gflops << std::endl;
return result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
//
hipDeviceProp_t props;
hipError_t error = hipGetDeviceProperties(&props, 0);
if (error != hipSuccess) {
std::cerr << "hipGetDeviceProperties() returned an error: " << hipGetErrorString(error) << std::endl;
return -1;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Define the BlockedEll type
//
using ElementA = hytlass::half_t;
using ElementB = hytlass::half_t;
using ElementOutput = hytlass::half_t;
using ElementAccumulator = float;
using LayoutA = hytlass::layout::RowMajor;
using LayoutB = hytlass::layout::RowMajor;
using LayoutC = hytlass::layout::RowMajor;
constexpr int32_t kAlignmentA = 128 / hytlass::sizeof_bits<ElementA>::value;
constexpr int32_t kAlignmentB = 128 / hytlass::sizeof_bits<ElementB>::value;
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>;
// Define the epilogue operation as LinearCombinationRelu, which approximately computes
//
//   d_ij = max(0, alpha * sum_k(a_ik * b_kj) + bias_j)
//
using EpilogueOp = hytlass::epilogue::thread::LinearCombinationRelu<
ElementOutput, // <- data type of output matrix
128 / hytlass::sizeof_bits<ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This becomes
// the vector width of math instructions in
// epilogue too
ElementAccumulator, // <- data type of accumulator
ElementAccumulator, // <- data type for alpha in linear combination function
hytlass::epilogue::thread::ScaleType::NoBetaScaling>; // <- alpha x C + bias
constexpr int32_t kStages = 2;
using Gemm = typename hytlass::gemm::device::EllGemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementOutput,
LayoutC,
ElementAccumulator,
hytlass::arch::OpClassTensorOp,
hytlass::arch::Gfx928,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
kStages, kAlignmentA, kAlignmentB>;
//
// Profile it
//
Testbed<Gemm> testbed(options);
if (!testbed.sufficient()) {
std::cout << "The active device lacks sufficient hardware resources to execute this kernel.\n";
return 0;
}
Result result = testbed.profile();
if (!result.passed) {
std::cout << "Profiling HYTLASS ELL block sparse GEMM has failed.\n";
std::cout << "\nFailed\n";
return -1;
}
std::cout << "\nPassed\n";
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gemm_with_absmax
gemm_with_absmax.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of running a GEMM with AbsMax scaling.
In addition to using Fp16 Tensor Core instructions, the GEMM uses a distinct epilogue
that enables additional scaling of operands/outputs, storing a pre-activation-function output
tensor (called the "auxiliary" output), and computing the absolute maximum value of the
outputs.
Pseudocode for this epilogue is as follows:
Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias
D = activation(Aux)
if Aux is fp8 type:
abs_max_output = max( abs(aux) | (for every aux in Aux))
Aux = scale_aux * Aux
endif
if D is fp8 type:
abs_max_output = max( abs(d) | (for every d in D))
D = scale_d * D
endif
Parameter Aux is optionally stored to global memory
*/
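/*
  Illustrative walk-through (hypothetical values, not computed by this example): with
  alpha = 1, beta = 0, scale_a = 2, scale_b = 0.5, an accumulator value of 3.0, and a bias
  of -4.0, the epilogue forms Aux = (1 * 2 * 0.5) * 3.0 + (-4.0) = -1.0 and
  D = ReLU(Aux) = 0. If D is an fp8 type, abs_max_output is updated with |0| and D is then
  multiplied by scale_d before being stored.
*/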
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/numeric_conversion.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/reference/host/gemm_complex.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/distribution.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_norm.h"
#include "hytlass/util/reference/host/gemm.h"
#include "hytlass/epilogue/thread/activation.h"
#include "hytlass/epilogue/thread/linear_combination_generic_with_scaling.h"
#include "hytlass/gemm/device/gemm_universal_with_absmax.h"
#include "hytlass/layout/matrix.h"
#include "hytlass/matrix_coord.h"
#include "hytlass/gemm/device/gemm_universal_adapter.h"
using ElementA = hytlass::half_t;
using ElementB = hytlass::half_t;
using ElementOutput = hytlass::float_e4m3_t;
using ElementAuxOutput = ElementOutput;
using ElementAccumulator = float;
using LayoutA = hytlass::layout::RowMajor;
using LayoutB = hytlass::layout::ColumnMajor;
using LayoutC = hytlass::layout::RowMajor;
static int const kStages = 1;
static int const kAlignmentA = 8;
static int const kAlignmentB = 8;
using EpilogueOutputOp = hytlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax<
hytlass::epilogue::thread::ReLu,
ElementOutput,
ElementAuxOutput,
8,
ElementAccumulator,
ElementAccumulator
>;
template <typename MathOperator>
using Gemm_ = hytlass::gemm::device::GemmUniversalWithAbsMax<
ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC,
ElementAccumulator, hytlass::arch::OpClassTensorOp, hytlass::arch::Gfx928,
hytlass::gemm::GemmShape<128, 64, 128>, hytlass::gemm::GemmShape<64, 32, 128>, hytlass::gemm::GemmShape<16, 16, 16>,
EpilogueOutputOp, hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages,
kAlignmentA, kAlignmentB, MathOperator
>;
using ElementAbsmax = typename EpilogueOutputOp::ElementAbsmax;
// Command line options parsing
struct Options {
bool help;
bool error;
bool reference_check;
hytlass::gemm::GemmCoord problem_size;
int iterations;
int warmup_iterations;
bool scale_A;
bool scale_B;
bool scale_C;
float alpha;
float beta;
Options():
help(false),
error(false),
reference_check(false),
iterations(20),
warmup_iterations(5),
scale_A(true),
scale_B(true),
scale_C(true),
alpha(1.f),
beta(0.f)
{ }
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, 5);
cmd.get_cmd_line_argument("reference-check", reference_check, false);
cmd.get_cmd_line_argument("scale-A", scale_A, true);
cmd.get_cmd_line_argument("scale-B", scale_B, true);
cmd.get_cmd_line_argument("scale-C", scale_C, true);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
int m, n, k;
cmd.get_cmd_line_argument("m", m, 1024);
cmd.get_cmd_line_argument("n", n, 1024);
cmd.get_cmd_line_argument("k", k, 1024);
problem_size = hytlass::gemm::GemmCoord{m, n, k};
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "18_gemm_with_abs_max\n\n"
<< " This example executes a GEMM with AbsMax Tensor Core operations. In addition to performing\n"
<< " a normal GEMM, the kernel performs the following operations:\n"
<< " Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias\n"
<< " D = activation(Aux)\n\n"
<< " if Aux is fp8:\n"
<< " abs_max_output = max( abs(aux) | (for every aux in Aux) )\n"
<< " Aux = scale_aux * Aux\n\n"
<< " if D is fp8 type:\n"
<< " abs_max_output = max( abs(d) | (for every d in D) )\n"
<< " D = scale_d * D\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M dimension of the GEMM\n"
<< " --n=<int> Sets the N dimension of the GEMM\n"
<< " --k=<int> Sets the K dimension of the GEMM\n"
<< " --scale-A=<bool> Whether to apply a scaling factor to operand A (default: true)\n"
<< " --scale-B=<bool> Whether to apply a scaling factor to operand B (default: true)\n"
<< " --scale-C=<bool> Whether to apply a scaling factor to operand C (default: true)\n"
<< " --iterations=<int> Number of profiling iterations to perform\n"
<< " --warmup-iterations=<int> Number of warmup iterations to perform\n"
<< " --reference-check=<bool> If true, performs reference check\n";
return out;
}
/// Compute performance in GFLOP/s
float gflops(float runtime_s) const {
// Two flops per multiply-add
return 2.0f * float(problem_size.product()) / float(1.0e9) / runtime_s;
}
};
/// Helper class to run the kernel
template <typename Gemm>
struct TestbedRunner {
using ElementAccumulator = typename Gemm::ElementAccumulator;
using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute;
using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor;
static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded;
static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded;
/// Initialization
hytlass::Distribution::Kind init_A;
hytlass::Distribution::Kind init_B;
hytlass::Distribution::Kind init_C;
uint64_t seed;
hytlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
hytlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
hytlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC> tensor_C;
hytlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementAuxOutput, typename Gemm::LayoutC> tensor_Aux;
hytlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, typename Gemm::LayoutC> tensor_D;
hytlass::HostTensor<typename Gemm::ElementC, typename Gemm::LayoutC> tensor_Vector;
hytlass::HostTensor<ElementAccumulator, typename Gemm::LayoutC> tmp_D;
hytlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementOutput, typename Gemm::LayoutC> reference_D;
hytlass::HostTensor<typename Gemm::EpilogueOutputOp::ElementAuxOutput, typename Gemm::LayoutC> reference_Aux;
hytlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_A;
hytlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_B;
hytlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_C;
hytlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_D;
hytlass::HostTensor<ElementScalingFactor, typename Gemm::LayoutC> scale_Aux;
hytlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> abs_max_Aux;
hytlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> abs_max_D;
hytlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> reference_abs_max_Aux;
hytlass::HostTensor<ElementAbsmax, typename Gemm::LayoutC> reference_abs_max_D;
//
// Methods
//
TestbedRunner(
bool scaleA = true,
bool scaleB = true,
bool scaleC = true,
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
/// Helper to initialize scaling factors
template <typename Element, typename Layout>
bool initialize_scale_factor(hytlass::TensorView<Element, Layout> view, uint64_t seed, int bits=0) {
hytlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits);
return true;
}
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
hytlass::TensorView<Element, Layout> view,
hytlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == hytlass::Distribution::Uniform) {
double scope_max, scope_min;
int bits_input = hytlass::sizeof_bits<Element>::value;
int bits_output = hytlass::sizeof_bits<typename Gemm::ElementC>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
scope_max = 5;
scope_min = -5;
} else {
scope_max = 8;
scope_min = -8;
}
hytlass::reference::host::TensorFillRandomUniform(
view, seed, scope_max, scope_min, 0);
}
else if (dist_kind == hytlass::Distribution::Identity) {
hytlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == hytlass::Distribution::Gaussian) {
hytlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == hytlass::Distribution::Sequential) {
hytlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else {
std::cerr << "Not implemented";
return false;
}
return true;
}
/// Initializes data structures
void initialize(const Options& options) {
//
// Allocate the GEMM workspace
//
tensor_A.resize(options.problem_size.mk());
tensor_B.resize(options.problem_size.kn());
tensor_C.resize(options.problem_size.mn());
tensor_D.resize(options.problem_size.mn());
tensor_Vector.resize({1, options.problem_size.n()});
reference_D.resize(options.problem_size.mn(), false);
tmp_D.resize(options.problem_size.mn(), false);
initialize_tensor(tensor_A.host_view(), init_A, seed + 2019);
initialize_tensor(tensor_B.host_view(), init_B, seed + 2018);
initialize_tensor(tensor_C.host_view(), init_C, seed + 2017);
initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020);
// The random initialization could produce all zeros, so override the upper-left element
// of each operand with a non-zero value.
hytlass::Coord<2> origin(0);
tensor_A.host_view().at(origin) = typename Gemm::ElementA(1);
tensor_B.host_view().at(origin) = typename Gemm::ElementB(1);
tensor_C.host_view().at(origin) = typename Gemm::ElementC(1);
tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1);
hytlass::reference::host::TensorFill(tensor_D.host_view());
hytlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
tensor_Vector.sync_device();
int scale_bits = 2;
if (options.scale_A) {
scale_A.resize({1, 1});
initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits);
scale_A.sync_device();
}
if (options.scale_B) {
scale_B.resize({1, 1});
initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits);
scale_B.sync_device();
}
if (options.scale_C) {
scale_C.resize({1, 1});
initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits);
scale_C.sync_device();
}
if (kScaleOutput) {
scale_D.resize({1, 1});
initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits);
scale_D.sync_device();
abs_max_D.resize({1, 1});
hytlass::reference::host::TensorFill(abs_max_D.host_view());
abs_max_D.sync_device();
reference_abs_max_D.resize({1, 1});
}
if (kScaleAux) {
tensor_Aux.resize(options.problem_size.mn());
hytlass::reference::host::TensorFill(tensor_Aux.host_view());
tensor_Aux.sync_device();
scale_Aux.resize({1, 1});
initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits);
scale_Aux.sync_device();
abs_max_Aux.resize({1, 1});
hytlass::reference::host::TensorFill(abs_max_Aux.host_view());
abs_max_Aux.sync_device();
reference_Aux.resize(options.problem_size.mn(), false);
reference_abs_max_Aux.resize({1, 1});
}
}
/// Compares computed reference with device reference and outputs to a file if incorrect
bool compare_reference(const Options& options) {
tensor_D.sync_host();
bool passed = hytlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
if (kScaleAux) {
tensor_Aux.sync_host();
abs_max_Aux.sync_host();
passed &= hytlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view());
passed &= hytlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view());
}
if (kScaleOutput) {
abs_max_D.sync_host();
passed &= hytlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view());
}
if (!passed) {
std::cerr << "Reference check failed" << std::endl;
std::string output_file = "testbed_with_amax_errors.txt";
std::ofstream file(output_file);
file
<< "problem: " << options.problem_size
<< ", alpha: " << options.alpha << ", beta: " << options.beta << "\n\n";
file
<< "A =\n" << tensor_A.host_view()
<< "\nB =\n" << tensor_B.host_view()
<< "\nC =\n" << tensor_C.host_view()
<< "\nVector =\n" << tensor_Vector.host_view()
<< "\nScaleA = " << scale_A.host_view()
<< "\nScaleB = " << scale_B.host_view()
<< "\nScaleC = " << scale_C.host_view()
<< "\nScaleD = " << scale_D.host_view()
<< "\nScaleAux = " << scale_Aux.host_view()
<< "\n\nReference D =\n" << reference_D.host_view()
<< "\nComputed D =\n" << tensor_D.host_view();
if (kScaleAux) {
file
<< "\n\nReference Aux =\n" << reference_Aux.host_view()
<< "\nComputed Aux =\n" << tensor_Aux.host_view()
<< "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view()
<< "\nComputed Absmax Aux = " << abs_max_Aux.host_view();
}
if (kScaleOutput) {
file
<< "\n\nReference Absmax D = " << reference_abs_max_D.host_view()
<< "\nComputed Absmax D = " << abs_max_D.host_view();
}
std::cerr << "Dumped results to " << output_file << std::endl;
}
return passed;
}
/// Verifies the device result against a host reference GEMM
bool verify(const Options& options) {
hytlass::Coord<2> origin(0);
ElementCompute scaled_alpha = options.alpha;
if (options.scale_A) {
scaled_alpha *= scale_A.host_view().at(origin);
}
if (options.scale_B) {
scaled_alpha *= scale_B.host_view().at(origin);
}
ElementCompute scaled_beta = options.beta;
if (options.scale_C) {
scaled_beta *= scale_C.host_view().at(origin);
}
//
// Verify
//
hytlass::reference::host::GemmComplex<
typename Gemm::ElementA, typename Gemm::LayoutA,
typename Gemm::ElementB, typename Gemm::LayoutB,
typename Gemm::ElementC, typename Gemm::LayoutC,
ElementCompute, ElementAccumulator, ElementAccumulator
>(
options.problem_size,
scaled_alpha,
tensor_A.host_ref(),
Gemm::kTransformA,
tensor_B.host_ref(),
Gemm::kTransformB,
scaled_beta,
tensor_C.host_ref(),
tmp_D.host_ref(),
ElementAccumulator(0)
);
ElementCompute tmp_abs_max_Aux(0.);
ElementCompute tmp_abs_max_D(0.);
hytlass::NumericConverter<ElementCompute, typename Gemm::ElementC> cvt_c_to_compute;
hytlass::NumericConverter<ElementCompute, ElementAccumulator> cvt_accum_to_compute;
hytlass::NumericConverter<ElementAccumulator, ElementCompute> cvt_compute_to_accum;
hytlass::NumericConverter<typename Gemm::EpilogueOutputOp::ElementOutput, ElementCompute> cvt_compute_to_d;
hytlass::NumericConverter<typename Gemm::EpilogueOutputOp::ElementAuxOutput, ElementCompute> cvt_compute_to_aux;
hytlass::absolute_value_op<ElementCompute> abs;
hytlass::maximum_with_nan_propogation<ElementCompute> max;
hytlass::epilogue::thread::ReLu<ElementCompute> act;
ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.);
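// Host reference epilogue: add the bias vector, apply ReLU, track the absolute maxima of
// Aux and D, and apply the output and auxiliary scaling factors.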
for (int m = 0; m < options.problem_size.m(); ++m) {
for (int n = 0; n < options.problem_size.n(); ++n) {
ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n}));
ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n}));
ElementCompute aux = intermediate + bias;
ElementCompute d = act(aux);
tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux);
tmp_abs_max_D = max(abs(d), tmp_abs_max_D);
reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale);
if (kScaleAux) {
reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin));
}
}
}
if (kScaleAux) {
reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_accum(tmp_abs_max_Aux);
}
if (kScaleOutput) {
reference_abs_max_D.host_view().at(origin) = cvt_compute_to_accum(tmp_abs_max_D);
}
return compare_reference(options);
}
/// Returns true if the HIP device is sufficient to execute the kernel.
bool sufficient() const {
size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage);
hipDeviceProp_t properties;
int device_idx;
hipError_t result = hipGetDevice(&device_idx);
if (result != hipSuccess) {
std::cerr << "hipGetDevice() failed with error: " << hipGetErrorString(result) << std::endl;
return false;
}
result = hipGetDeviceProperties(&properties, device_idx);
if (result != hipSuccess) {
std::cerr << "hipGetDeviceProperties() failed with error: " << hipGetErrorString(result) << std::endl;
return false;
}
if (properties.sharedMemPerBlock < smem_size) {
return false;
}
return true;
}
/// Executes one test
bool run(Options& options)
{
// Waive test if insufficient hip device
if (!sufficient()) {
std::cerr << "Insufficient resources to run the kernel." << std::endl;
return false;
}
this->initialize(options);
//
// Initialize the GEMM operator
//
typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{
ElementCompute(options.alpha),
ElementCompute(options.beta)
};
typename Gemm::EpilogueOutputOp::Params epilogue_params{
activation_params,
scale_A.device_data(),
scale_B.device_data(),
scale_C.device_data(),
scale_D.device_data(),
scale_Aux.device_data(),
abs_max_Aux.device_data(),
abs_max_D.device_data()
};
typename Gemm::Arguments arguments{
hytlass::gemm::GemmUniversalMode::kGemm,
options.problem_size,
/* batch_count = */ 1,
epilogue_params,
tensor_A.device_data(),
tensor_B.device_data(),
tensor_C.device_data(),
tensor_D.device_data(),
tensor_Aux.device_data(),
tensor_Vector.device_data(),
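// Batch strides for A, B, C, and D (a single batch, so each stride spans the whole tensor)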
options.problem_size.m() * options.problem_size.k(),
options.problem_size.n() * options.problem_size.k(),
options.problem_size.m() * options.problem_size.n(),
options.problem_size.m() * options.problem_size.n(),
(int)options.problem_size.m(), // Batch stride vector
tensor_A.layout().stride(0),
tensor_B.layout().stride(0),
tensor_C.layout().stride(0),
tensor_D.layout().stride(0),
(int64_t)0 // Leading dimension of vector. This must be 0
};
Gemm gemm_op;
hytlass::Status status = gemm_op.can_implement(arguments);
if (status != hytlass::Status::kSuccess) {
std::cerr << "Gemm::can_implement() failed" << std::endl;
return false;
}
size_t workspace_size = Gemm::get_workspace_size(arguments);
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
status = gemm_op.initialize(arguments, workspace.get());
if (status != hytlass::Status::kSuccess) {
std::cerr << "Gemm::initialize() failed" << std::endl;
return false;
}
//
// Run the GEMM
//
status = gemm_op();
if (status != hytlass::Status::kSuccess) {
std::cerr << "Gemm::run() failed" << std::endl;
return false;
}
hipError_t hip_error = hipDeviceSynchronize();
if (hip_error != hipSuccess) {
std::cerr << "hip error: " << hipGetErrorString(hip_error) << std::endl;
return false;
}
//
// Verify
//
bool passed = true;
if (options.reference_check) {
passed &= this->verify(options);
} else {
std::cout << "Skipped reference check" << std::endl;
}
//
// Warm up
//
for (int i = 0; i < options.warmup_iterations; ++i) {
gemm_op();
}
//
// Profile
//
hipEvent_t events[2];
hipError_t error;
for (auto & event : events) {
error = hipEventCreate(&event);
if (error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(error) << std::endl;
return false;
}
}
// Record an event at the start of a series of GEMM operations
error = hipEventRecord(events[0]);
if (error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(error) << std::endl;
return false;
}
// Run profiling loop
for (int iter = 0; iter < options.iterations; ++iter) {
gemm_op();
}
// Record an event when the GEMM operations have been launched.
error = hipEventRecord(events[1]);
if (error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(error) << std::endl;
return false;
}
// Wait for work on the device to complete.
error = hipEventSynchronize(events[1]);
if (error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(error) << std::endl;
return false;
}
// Measure elapsed runtime
float runtime_ms = 0;
error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(error) << std::endl;
return false;
}
// Compute average runtime and GFLOPs.
runtime_ms = runtime_ms / float(options.iterations);
float gflops = options.gflops(runtime_ms / 1000.0f);
std::cout << "Problem size: " << options.problem_size.m() << 'x' << options.problem_size.n() << 'x' << options.problem_size.k() << std::endl;
std::cout << "Runtime (ms): " << runtime_ms << std::endl;
std::cout << "GFLOPs/sec: " << gflops << std::endl;
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
return passed;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const** argv) {
hipDeviceProp_t props;
hipError_t error = hipGetDeviceProperties(&props, 0);
if (error != hipSuccess) {
std::cerr << "hipGetDeviceProperties() returned an error: " << hipGetErrorString(error) << std::endl;
return -1;
}
//
// Parse options
//
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
std::cout << "Running GEMM with staged accumulation (OpMultiplyAdd)" << std::endl;
std::cout << "=====================================================" << std::endl;
TestbedRunner<Gemm_<hytlass::arch::OpMultiplyAdd>> testbed_staged_accum;
bool passed = testbed_staged_accum.run(options);
if (passed) {
std::cout << "Passed" << std::endl;
} else {
std::cout << "Failed" << std::endl;
}
return 0;
}