Commit d22dbec2 authored by zhoux

Initial commit: release hytlass-0.1.0
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/***************************************************************************************************
Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
"classic data-parallel" and "Split-K" decompositions, with a residual add fused into the epilogue.
For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
**************************************************************************************************/
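/***************************************************************************************************
  Informal sketch (not library code, added for orientation) of how the three decompositions
  assign the K-dimension MAC iterations of each output tile to threadblocks:

    // Classic data-parallel: one CTA owns one output tile and all of its K iterations.
    for each output tile T (in parallel): accum(T) = sum over all k-iterations of T

    // Split-K: each tile's K range is cut into `split` fragments computed by `split` CTAs,
    // whose partial sums are then reduced.
    for each (tile T, fragment f) (in parallel): partial(T, f) = sum over fragment f of T
    accum(T) = reduce over f of partial(T, f)

    // Stream-K: the flattened list of all (tile, k-iteration) work units is divided evenly
    // across the available SMs, so a CTA may finish one tile's remainder and begin the next
    // tile mid-K; CTAs sharing a tile reduce their partials (see arXiv:2301.03598).
**************************************************************************************************/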
#include <iostream>
#include <string>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm_universal.h"
#include "hytlass/gemm/device/gemm_universal_with_broadcast.h"
#include "hytlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "hytlass/epilogue/thread/linear_combination_residual_block.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/error_metrics.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_foreach.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/epilogue/threadblock/fusion/visitors.hpp"
#include "hytlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "hytlass/gemm/device/gemm_universal_adapter.h"
#include "hytlass/gemm/device/gemm_universal_base.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations (hytlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = hytlass::half_t; // Element type for A matrix operand
using LayoutA = hytlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / hytlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = hytlass::half_t; // Element type for B matrix operand
using LayoutB = hytlass::layout::RowMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / hytlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C1/C2/D matrix configuration
using ElementC = hytlass::half_t; // Element type for C matrix operands
using LayoutC = hytlass::layout::RowMajor; // Layout type for C matrix operands
constexpr int AlignmentC = 128 / hytlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = hytlass::half_t; // Element type for output matrix operands
using LayoutOutput = hytlass::layout::RowMajor; // Layout type for output matrix operands
// constexpr int AlignmentOutput = 128 / hytlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for compute
using ArchTag = hytlass::arch::Gfx928; // Tag indicating the minimum Gfx that supports the intended feature
using OperatorClass = hytlass::arch::OpClassTensorOp; // Operator class tag
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 1; // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1; // Number of epilogue stages in EVT
// Residual block configuration
// Epilogue output operator
/// Using LinearCombinationResidualBlock
/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
using EpilogueOp = hytlass::epilogue::thread::LinearCombinationResidualBlock<
ElementOutput, // Element type for output matrix
ElementAccumulator, // Element type from internal accumulation
ElementCompute, // Element type used for epilogue computation
ElementC, // Element type for C1/C2/D matrix operands
AlignmentC, // Memory access granularity of C and D matrix in units of elements
hytlass::epilogue::thread::Identity, // Activation
hytlass::plus, // Binary operation 1
hytlass::epilogue::thread::Identity, // Unary operation
hytlass::plus // Binary operation 2
>;
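// With the arguments above (Identity activation, Identity unary op, plus for both binary
// ops), the modeled expression specializes, per output element, to roughly:
//   D = Identity(plus(plus(Identity(alpha * accum + bias), beta * C1), C2))
//     = alpha * accum + bias + beta * C1 + C2
// which is what the host-side reference check below computes (GEMM with alpha/beta on C1,
// then adding the broadcast vector and C2).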
// Reference device GEMM implementation type
using DeviceGemmReference = hytlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
// Classic data-parallel device GEMM implementation type
using DeviceGemmBasic = hytlass::gemm::device::GemmUniversalWithBroadcast<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
NumStages,
AlignmentA,
AlignmentB>;
// StreamK device GEMM implementation type with EVT
using namespace hute;
using OutputTileThreadMap = hytlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementC,
AlignmentC,
EVTEpilogueStages
>;
using Accum = hytlass::epilogue::threadblock::VisitorAccFetch;
// Data loads for the epilogue (bias row-broadcast and auxiliary C1/C2 tensors)
using Bias = hytlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementC,
hute::Stride<_0, _1, int32_t> // StrideMNL
>;
using C1 = hytlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, ElementC,
hute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using C2 = hytlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, ElementC,
hute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using Compute0 = hytlass::epilogue::threadblock::VisitorCompute<
hytlass::plus, ElementCompute, ElementCompute,
hytlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute0 = hytlass::epilogue::threadblock::Sm80EVT<
Compute0,
Accum,
Bias>;
using Compute1 = hytlass::epilogue::threadblock::VisitorCompute<
hytlass::plus, ElementCompute, ElementCompute,
hytlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute1 = hytlass::epilogue::threadblock::Sm80EVT<
Compute1,
EVTCompute0,
C1>;
using Compute2 = hytlass::epilogue::threadblock::VisitorCompute<
hytlass::plus, ElementOutput, ElementCompute,
hytlass::FloatRoundStyle::round_to_nearest
>;
using EVTCompute2 = hytlass::epilogue::threadblock::Sm80EVT<
Compute2,
EVTCompute1,
C2>;
using D = hytlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementOutput, hytlass::FloatRoundStyle::round_to_nearest,
hute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using EVTD = hytlass::epilogue::threadblock::Sm80EVT<
D,
EVTCompute2>;
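// Read bottom-up, the visitor tree assembled above is:
//   EVTD = AuxStore( Compute2( Compute1( Compute0(Accum, Bias), C1 ), C2 ) )
// i.e. D = ((accum + bias) + C1) + C2, the same residual-add chain that
// LinearCombinationResidualBlock expresses for the data-parallel kernel above, here built
// from composable epilogue-visitor nodes instead.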
using EVTKernelStreamK =
typename hytlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, hytlass::ComplexTransform::kNone, AlignmentA,
ElementB, LayoutB, hytlass::ComplexTransform::kNone, AlignmentB,
ElementC, LayoutC, AlignmentC,
ElementAccumulator,
ElementCompute,
hytlass::arch::OpClassTensorOp,
hytlass::arch::Gfx928,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
hytlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
hytlass::arch::OpMultiplyAdd,
EVTEpilogueStages
>::GemmKernel;
using DeviceGemmStreamK = hytlass::gemm::device::GemmUniversalBase<EVTKernelStreamK>;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
/// Command line options parsing
struct Options {
std::string command_name;
bool help;
hytlass::gemm::GemmCoord problem_size;
float alpha;
float beta;
int split_k_factor;
int avail_sms;
int iterations;
bool real;
int matrix_value;
hytlass::HostTensor<ElementA, LayoutA> tensor_a;
hytlass::HostTensor<ElementB, LayoutB> tensor_b;
hytlass::HostTensor<ElementC, LayoutC> tensor_c1;
hytlass::HostTensor<ElementC, LayoutC> tensor_c2;
hytlass::HostTensor<ElementC, LayoutC> tensor_d;
hytlass::HostTensor<ElementC, LayoutC> tensor_ref_d;
hytlass::HostTensor<ElementC, LayoutC> tensor_Vector;
// hytlass::HostTensor<ElementC, LayoutC> tensor_Tensor;
Options(std::string command_name) :
command_name(command_name),
help(false),
problem_size({2048, 2048, 2048}),
alpha(1.0f),
beta(1.0f),
split_k_factor(1),
avail_sms(-1), // Number of device SMs to use is unlimited
iterations(100),
real(false),
matrix_value(0)
{}
bool valid() const {
return true;
}
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("split", split_k_factor);
cmd.get_cmd_line_argument("iterations", iterations);
real = cmd.check_cmd_line_flag("real");
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out
<< "Performs a GEMM computation.\n"
<< "\n"
<< "Options:\n"
<< "\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --split=<int> Split-K factor to emulate\n\n"
<< " --real If specified, initializes with real values instead of whole numbers. Errors are to be expected.\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options
typename DeviceGemmBasic::Arguments args_from_options(
const DeviceGemmBasic &device_gemm,
const Options &options,
hytlass::HostTensor<ElementA, LayoutA> &tensor_a,
hytlass::HostTensor<ElementB, LayoutB> &tensor_b,
hytlass::HostTensor<ElementC, LayoutC> &tensor_c1,
hytlass::HostTensor<ElementC, LayoutC> &tensor_c2,
hytlass::HostTensor<ElementC, LayoutC> &tensor_d,
hytlass::HostTensor<ElementC, LayoutC> &tensor_Vector /*,
hytlass::HostTensor<ElementC, LayoutC> &tensor_Tensor */
)
{
return typename DeviceGemmBasic::Arguments(
hytlass::gemm::GemmUniversalMode::kGemm, // universal mode
options.problem_size, // problem_size
options.split_k_factor, // batch count / splitk slices
{ // epilogue parameters
ElementAccumulator(options.alpha),
ElementAccumulator(options.beta)
},
tensor_a.device_data(), // ptr_A
tensor_b.device_data(), // ptr_B
tensor_c1.device_data(), // ptr_C1
tensor_c2.device_data(), // ptr_C2
tensor_d.device_data(), // ptr_D
tensor_Vector.device_data(), // ptr_Vector
/* tensor_Tensor.device_data(), */nullptr,// ptr_Tensor
options.problem_size.mk().product(), // batch_stride_A
options.problem_size.nk().product(), // batch_stride_B
options.problem_size.mn().product(), // batch_stride_C1
options.problem_size.mn().product(), // batch_stride_C2
options.problem_size.mn().product(), // batch_stride_D
options.problem_size.mn().product(), // batch_stride_Vector
options.problem_size.mn().product(), // batch_stride_Tensor
tensor_a.layout().stride(0), // stride_a
tensor_b.layout().stride(0), // stride_b
tensor_c1.layout().stride(0), // stride_c1
tensor_c2.layout().stride(0), // stride_c2
tensor_d.layout().stride(0), // stride_d
/*tensor_Vector.layout().stride(0)*/0, // stride_Vector
/*tensor_Tensor.layout().stride(0)*/0); // stride_Tensor
}
/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
typename DeviceGemmStreamK::Arguments args_from_options(
const DeviceGemmStreamK &device_gemm,
const Options &options,
hytlass::HostTensor<ElementA, LayoutA> &tensor_a,
hytlass::HostTensor<ElementB, LayoutB> &tensor_b,
hytlass::HostTensor<ElementC, LayoutC> &tensor_c1,
hytlass::HostTensor<ElementC, LayoutC> &tensor_c2,
hytlass::HostTensor<ElementC, LayoutC> &tensor_d,
hytlass::HostTensor<ElementC, LayoutC> &tensor_Vector/*,
hytlass::HostTensor<ElementC, LayoutC> &tensor_Tensor*/
)
{
typename EVTD::Arguments callback_args{
{
{
{
{}, // Accum
{tensor_Vector.device_data(), ElementC(0), {_0{}, _1{}, int32_t(options.problem_size.n())}}, // Bias
{} // Compute0
}, // EVTCompute0
{tensor_c1.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C1
{} // Compute1
}, // EVTCompute1
{tensor_c2.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C2
{} // Compute2
}, // EVTCompute2
{tensor_d.device_data(), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // D
}; // EVTD
return typename DeviceGemmStreamK::Arguments(
hytlass::gemm::GemmUniversalMode::kGemm, // universal mode
options.problem_size, // problem_size
options.split_k_factor, // batch count / splitk slices
callback_args, // argument of EVT callbacks
tensor_a.device_data(), // ptr_A
tensor_b.device_data(), // ptr_B
nullptr, // ptr_C (unused)
nullptr, // ptr_D (unused)
options.problem_size.mk().product(), // batch_stride_A
options.problem_size.nk().product(), // batch_stride_B
0, // batch_stride_C (unused)
0, // batch_stride_D (unused)
tensor_a.layout().stride(0), // stride_a
tensor_b.layout().stride(0), // stride_b
0, // stride_c (unused)
0, // stride_d (unused)
options.avail_sms); // avail_sms
}
/// Execute a given example GEMM computation
template <typename DeviceGemmT>
Result run(std::string description, Options &options) {
// Display test description
std::cout << std::endl << description << std::endl;
// Zero-initialize test output matrix D
hytlass::reference::host::TensorFill(options.tensor_d.host_view());
options.tensor_d.sync_device();
// Instantiate HYTLASS kernel depending on templates
DeviceGemmT device_gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
auto arguments = args_from_options(device_gemm, options,
options.tensor_a, options.tensor_b, options.tensor_c1, options.tensor_c2, options.tensor_d,
options.tensor_Vector/*, options.tensor_Tensor*/);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
HYTLASS_CHECK(device_gemm.can_implement(arguments));
// Initialize HYTLASS kernel with arguments and workspace pointer
HYTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
HYTLASS_CHECK(device_gemm());
// Copy output data from HYTLASS and reference kernel to host for comparison
options.tensor_d.sync_host();
// Check if output from HYTLASS kernel and reference kernel are equal or not
Result result;
ElementC eps(1e-3);
ElementC non_zero_floor(1e-6);
// Reference check
result.passed = hytlass::reference::host::TensorRelativelyEquals(options.tensor_d.host_view(),
options.tensor_ref_d.host_view(), eps, non_zero_floor);
double err = hytlass::reference::host::TensorRelativeErrorMetric(
options.tensor_d.host_view(),
options.tensor_ref_d.host_view());
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl;
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
HYTLASS_CHECK(device_gemm());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPs: " << result.gflops << std::endl;
}
// TODO: uncomment when results match
//if (!result.passed) {
// exit(-1);
//}
return result;
}
/// Program entrypoint
int main(int argc, const char **argv) {
// Parse commandline options
Options options("gfx928_streamk_broadcast_gemm");
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
std::cout <<
options.iterations << " timing iterations of " <<
options.problem_size.m() << " x " <<
options.problem_size.n() << " x " <<
options.problem_size.k() << " matrix-matrix multiply" << std::endl;
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
//
// Initialize GEMM datasets
//
// Initialize tensors using HYTLASS helper functions
options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K
options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N
options.tensor_c1.resize(options.problem_size.mn()); // <- Create matrix C1 with dimensions M x N
options.tensor_c2.resize(options.problem_size.mn()); // <- Create matrix C2 with dimensions M x N
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from HYTLASS kernel
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
options.tensor_Vector.resize({1, options.problem_size.n()}); // <- Create broadcast vector with dimensions 1 x N
// options.tensor_Tensor.resize(options.problem_size.mn()); // <- Create T matrix with dimensions M x N
int _init_bits = options.real ? -1 : 0;
// Fill matrix A on host with uniform-random data [-2, 2]
hytlass::reference::host::TensorFillRandomUniform(
options.tensor_a.host_view(),
1,
ElementA(2),
ElementA(-2), _init_bits);
// Fill matrix B on host with uniform-random data [-2, 2]
hytlass::reference::host::TensorFillRandomUniform(
options.tensor_b.host_view(),
1,
ElementB(2),
ElementB(-2), _init_bits);
// Fill matrix C1 on host with uniform-random data [-2, 2]
hytlass::reference::host::TensorFillRandomUniform(
options.tensor_c1.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
// Fill matrix C2 on host with uniform-random data [-2, 2]
hytlass::reference::host::TensorFillRandomUniform(
options.tensor_c2.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
hytlass::reference::host::TensorFillRandomUniform(
options.tensor_Vector.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
//
// Compute reference output
//
// Copy data from host to GPU
options.tensor_a.sync_device();
options.tensor_b.sync_device();
options.tensor_c1.sync_device();
options.tensor_c2.sync_device();
options.tensor_Vector.sync_device();
// options.tensor_Tensor.sync_device();
// Zero-initialize reference output matrix D
hytlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
options.tensor_ref_d.sync_device();
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
options.problem_size,
ElementAccumulator(options.alpha),
options.tensor_a.device_ref(),
options.tensor_b.device_ref(),
ElementAccumulator(options.beta),
options.tensor_c1.device_ref(),
options.tensor_ref_d.device_ref());
// Wait for kernels to finish
HIP_CHECK(hipDeviceSynchronize());
// Copy output data from reference kernel to host for comparison
options.tensor_ref_d.sync_host();
// Add broadcast vector (without multiplier)
// This is only possible because BinaryOp is addition, and UnaryOps are identity.
// This makes the addition of the broadcast vector commute with the rest of the epilogue.
/// identity(plus(identity(alpha * (a * b) + v), beta * c)) ==
/// alpha * a * b + v + beta * c ==
/// (alpha * a * b + beta * c) + v ==
/// GEMM(a, b, c) + v
// Vector broadcast on host
for (int i=0; i < options.problem_size.m(); ++i) {
for (int j=0; j < options.problem_size.n(); ++j) {
options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_Vector.host_view().ref().at({0, j});
options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_c2.host_view().ref().at({i, j});
}
}
// Sync back with device just in case
options.tensor_ref_d.sync_device();
//
// Evaluate HYTLASS kernels
//
// Test default operation
if (options.split_k_factor == 1) {
// Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
Result basic_dp = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
Result streamk_default = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);
// printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));
// Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
options.avail_sms = 1; // Set load-balancing width to 1 SM (no load balancing)
Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
options.avail_sms = -1; // Reset load-balancing width to unspecified SMs (i.e., the number of device SMs)
// // printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));
// options.split_k_factor++; // Increment splitting factor for next evaluation
}
// Show that StreamK can emulate "Split-K" with a tile-splitting factor
// Result basic_splitk = run<DeviceGemmBasic>(
// std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
// options);
Result streamk_splitk = run<DeviceGemmStreamK>(
std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
options);
// printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));
return 0;
}
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/***************************************************************************************************
Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
"classic data-parallel" and "Split-K" decompositions.
For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
Requires HYGON gfx928 or newer device (Gfx928+).
- Build and run:
hytlass$ mkdir build
hytlass$ cd build
hytlass/build$ cmake .. -DHYTLASS_HIPCC_ARCHS=928
hytlass/build$ make gfx928_streamk_gemm
hytlass/build$ ./examples/03_hytlass_streamk_gemm/gfx928_streamk_gemm
10000 timing iterations of 2048 x 2048 x 2048 matrix-matrix multiply
Basic data-parallel GEMM
Disposition: Passed
Avg runtime: 0.112633 ms
GFLOPs: 152530
StreamK GEMM with default load-balancing
Disposition: Passed
Avg runtime: 0.0941929 ms
GFLOPs: 182390
Speedup vs Basic-DP: 1.196
StreamK emulating basic data-parallel GEMM
Disposition: Passed
Avg runtime: 0.113119 ms
GFLOPs: 151875
Speedup vs Basic-DP: 0.996
Basic split-K GEMM with tile-splitting factor 2
Disposition: Passed
Avg runtime: 0.104772 ms
GFLOPs: 163973
StreamK emulating Split-K GEMM with tile-splitting factor 2
Disposition: Passed
Avg runtime: 0.105379 ms
GFLOPs: 163029
Speedup vs Basic-SplitK: 0.994
**************************************************************************************************/
#include <iostream>
#include <string>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm_universal.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations (hytlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = hytlass::half_t; // Element type for A matrix operand
using LayoutA = hytlass::layout::ColumnMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / hytlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = hytlass::half_t; // Element type for B matrix operand
using LayoutB = hytlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / hytlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = hytlass::half_t; // Element type for C and D matrix operands
using LayoutC = hytlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / hytlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes)
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = hytlass::arch::Gfx928; // Tag indicating the minimum Gfx that supports the intended feature
using OperatorClass = hytlass::arch::OpClassTensorOp; // Operator class tag
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 1; // Number of global->shared pipeline stages used in the GEMM mainloop
// Epilogue output operator
using EpilogueOp = hytlass::epilogue::thread::LinearCombination<
ElementC, // Element type for C and D matrix operands
AlignmentC, // Memory access granularity of C and D matrix in units of elements
ElementAccumulator, // Element type from internal accumulation
ElementAccumulator>; // Data type used to compute linear combination
// Reference device GEMM implementation type
using DeviceGemmReference = hytlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
// StreamK device GEMM implementation type
using DeviceGemmStreamK = hytlass::gemm::device::GemmUniversal<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
hytlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- Only difference
NumStages,
AlignmentA,
AlignmentB>;
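// For comparison: a classic data-parallel instantiation of this kernel would differ only in
// the swizzle, e.g. hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> (as
// DeviceGemmBasic uses in the broadcast example above); everything else stays the same.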
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double avg_runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
/// Command line options parsing
struct Options {
std::string command_name;
bool help;
hytlass::gemm::GemmCoord problem_size;
float alpha;
float beta;
int split_k_factor;
int avail_sms;
bool reference_check;
int iterations;
hytlass::HostTensor<ElementA, LayoutA> tensor_a;
hytlass::HostTensor<ElementB, LayoutB> tensor_b;
hytlass::HostTensor<ElementC, LayoutC> tensor_c;
hytlass::HostTensor<ElementC, LayoutC> tensor_d;
hytlass::HostTensor<ElementC, LayoutC> tensor_ref_d;
Options(std::string command_name) :
command_name(command_name),
help(false),
problem_size({2048, 2048, 2048}),
alpha(1.0f),
beta(0.0f),
split_k_factor(1),
avail_sms(-1), // Number of device SMs to use is unlimited
reference_check(true),
iterations(100)
{}
bool valid() const {
return true;
}
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("split", split_k_factor);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "03_hytlass_streamk_gemm example\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --split=<int> Split-K factor to emulate\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/03_hytlass_streamk_gemm/gfx928_streamk_gemm --m=1024 --n=512 --k=1024 \\\n"
<< " --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Two flops per multiply-add
return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
}
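// Worked example: for the default 2048 x 2048 x 2048 problem this is
// 2 * 2048^3 / 1e9 ~= 17.18 GFLOP, so the 0.112633 ms average runtime in the sample
// output above corresponds to 17.18 GFLOP / 0.112633 ms ~= 152,530 GFLOP/s.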
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
typename DeviceGemmStreamK::Arguments args_from_options(
const DeviceGemmStreamK &device_gemm,
const Options &options,
hytlass::HostTensor<ElementA, LayoutA> &tensor_a,
hytlass::HostTensor<ElementB, LayoutB> &tensor_b,
hytlass::HostTensor<ElementC, LayoutC> &tensor_c,
hytlass::HostTensor<ElementC, LayoutC> &tensor_d)
{
return typename DeviceGemmStreamK::Arguments(
hytlass::gemm::GemmUniversalMode::kGemm, // universal mode
options.problem_size, // problem_size
options.split_k_factor, // batch count / splitk slices
{ // epilogue parameters
ElementAccumulator(options.alpha),
ElementAccumulator(options.beta)
},
tensor_a.device_data(), // ptr_A
tensor_b.device_data(), // ptr_B
tensor_c.device_data(), // ptr_C
tensor_d.device_data(), // ptr_D
options.problem_size.mk().product(), // batch_stride_A
options.problem_size.nk().product(), // batch_stride_B
options.problem_size.mn().product(), // batch_stride_C
options.problem_size.mn().product(), // batch_stride_D
tensor_a.layout().stride(0), // stride_a
tensor_b.layout().stride(0), // stride_b
tensor_c.layout().stride(0), // stride_c
tensor_d.layout().stride(0), // stride_d
options.avail_sms); // avail_sms
}
/// Execute a given example GEMM computation
template <typename DeviceGemmT>
Result run(std::string description, Options &options) {
// Display test description
std::cout << std::endl << description << std::endl;
// Zero-initialize test output matrix D
hytlass::reference::host::TensorFill(options.tensor_d.host_view());
options.tensor_d.sync_device();
// Instantiate HYTLASS kernel depending on templates
DeviceGemmT device_gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
HYTLASS_CHECK(device_gemm.can_implement(arguments));
// Initialize HYTLASS kernel with arguments and workspace pointer
HYTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
HYTLASS_CHECK(device_gemm());
// Copy output data from HYTLASS and reference kernel to host for comparison
options.tensor_d.sync_host();
// Check if output from HYTLASS kernel and reference kernel are equal or not
Result result;
ElementC eps(1e-3);
ElementC non_zero_floor(1e-6);
// Reference check
result.passed = hytlass::reference::host::TensorRelativelyEquals(options.tensor_d.host_view(),
options.tensor_ref_d.host_view(), eps, non_zero_floor);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
// Run profiling loop
if (options.iterations > 0) {
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
HYTLASS_CHECK(device_gemm());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPs: " << result.gflops << std::endl;
}
if (!result.passed) {
printf("Verification failed\n");
}
return result;
}
/// Program entrypoint
int main(int argc, const char **argv) {
// Parse commandline options
Options options("gfx928_streamk_gemm");
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
std::cout <<
options.iterations << " timing iterations of " <<
options.problem_size.m() << " x " <<
options.problem_size.n() << " x " <<
options.problem_size.k() << " matrix-matrix multiply" << std::endl;
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
//
// Initialize GEMM datasets
//
// Initialize tensors using HYTLASS helper functions
options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K
options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N
options.tensor_c.resize(options.problem_size.mn()); // <- Create matrix C with dimensions M x N
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from HYTLASS kernel
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
// Fill matrix A on host with uniform-random data [-2, 2]
hytlass::reference::host::TensorFillRandomUniform(
options.tensor_a.host_view(),
1,
ElementA(2),
ElementA(-2),
0);
// Fill matrix B on host with uniform-random data [-2, 2]
hytlass::reference::host::TensorFillRandomUniform(
options.tensor_b.host_view(),
1,
ElementB(2),
ElementB(-2),
0);
// Fill matrix C on host with uniform-random data [-2, 2]
hytlass::reference::host::TensorFillRandomUniform(
options.tensor_c.host_view(),
1,
ElementC(2),
ElementC(-2),
0);
//
// Compute reference output
//
// Copy data from host to GPU
options.tensor_a.sync_device();
options.tensor_b.sync_device();
options.tensor_c.sync_device();
// Zero-initialize reference output matrix D
hytlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
options.tensor_ref_d.sync_device();
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
options.problem_size,
ElementAccumulator(options.alpha),
options.tensor_a.device_ref(),
options.tensor_b.device_ref(),
ElementAccumulator(options.beta),
options.tensor_c.device_ref(),
options.tensor_ref_d.device_ref());
// Wait for kernels to finish
HIP_CHECK(hipDeviceSynchronize());
// Copy output data from reference kernel to host for comparison
options.tensor_ref_d.sync_host();
//
// Evaluate HYTLASS kernels
//
// Test default operation
if (options.split_k_factor == 1) {
Result streamk_default = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);
// Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
options.avail_sms = 1; // Set load-balancing width to 1 SM (no load balancing)
Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
options.avail_sms = -1; // Reset load-balancing width to unspecified SMs (i.e., the number of device SMs)
options.split_k_factor++; // Increment splitting factor for next evaluation
}
Result streamk_splitk = run<DeviceGemmStreamK>(
std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
options);
return 0;
}
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gfx928_batch_gemm
gfx928_batch_gemm.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include <vector>
#include "hytlass/util/command_line.h"
#include "hip/hip_runtime.h"
#include "hytlass/hytlass.h"
#include "hytlass/layout/matrix.h"
#include "hytlass/gemm/device/gemm_batched.h"
#pragma warning( disable : 4503)
/*
This example demonstrates how to use hytlass to compute a batched strided gemm: pointers to
the first matrices of the batch are specified together with the stride between consecutive
matrices of the batch (this is called a strided batched gemm). The alternative of copying
pointers to every matrix of the batch to device memory (an array gemm) is not exercised here.
In this example, both A and B are non-transposed, column-major matrices, computing
batched_C = batched_A x batched_B
As an example, matrix C can be seen as
-----------------------------------------------------------
(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
-----------------------------------------------------------
(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------------------------------------
(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |
-----------------------------------------------------------
(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |
-----------------------------------------------------------
(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |
-----------------------------------------------------------
(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |
-----------------------------------------------------------
batch 0 | batch 1
where we denote each element with (batch_idx, row_idx, column_idx)
In this example, batch size is 2, M is 6 and N is 3
The stride (batch_stride_C) between the first element of two batches is ldc * n
matrix A can be seen as
---------------------------------------
(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |
---------------------------------------
(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |
---------------------------------------
(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |
---------------------------------------
(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |
---------------------------------------
(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |
---------------------------------------
(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |
---------------------------------------
batch 0 | batch 1
, where batch size is 2, M is 6 and K is 2
The stride (batch_stride_A) between the first element of two batches is lda * k
matrix B can be seen as
-----------------------------
(0,0,0) | (0,0,1) | (0,0,2) |
----------------------------- batch 0
(0,1,0) | (0,1,1) | (0,1,2) |
-------------------------------------
(1,0,0) | (1,0,1) | (1,0,2) |
----------------------------- batch 1
(1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------
, where the batch size is 2, N is 3 and K is 2
The stride (batch_stride_B) between the first element of two batches is k
*/
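/*
Putting the three layouts together: with all matrices column-major (row index fastest),
element (batch_idx, row, col) of each operand is addressed as
  A[batch_idx * batch_stride_A + col * lda + row],  batch_stride_A = lda * k
  B[batch_idx * batch_stride_B + col * ldb + row],  batch_stride_B = k   (batches packed along K)
  C[batch_idx * batch_stride_C + col * ldc + row],  batch_stride_C = ldc * n
These are exactly the index expressions used by strided_batched_gemm_nn_reference below.
*/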
///////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::gemm::GemmCoord problem_size;
int batch_count;
float alpha;
float beta;
bool reference_check;
int iterations;
Options():
help(false),
problem_size({1024, 1024, 1024}),
batch_count(2),
alpha(1),
beta(),
reference_check(false),
iterations(0)
{}
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("batch_count", batch_count);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "04_hytlass_batch_gemm example\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --batch_count=<int> Batch number\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/04_hytlass_batch_gemm/gfx928_batch_gemm --m=1024 --n=512 --k=1024 \\\n"
<< " --alpha=2 --beta=0.707 --batch_count=2 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = problem_size.product() * batch_count;
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
hipError_t hytlass_strided_batched_sgemm(
int m,
int n,
int k,
float alpha,
float const *A,
int lda,
long long int batch_stride_A,
float const *B,
int ldb,
long long int batch_stride_B,
float *C,
int ldc,
long long int batch_stride_C,
float beta,
int batch_count) {
using Gemm = hytlass::gemm::device::GemmBatched<
float, hytlass::layout::ColumnMajor, float, hytlass::layout::ColumnMajor,
float, hytlass::layout::ColumnMajor, float,
hytlass::arch::OpClassTensorOp, hytlass::arch::Gfx928,
hytlass::gemm::GemmShape<128, 128, 32>,
hytlass::gemm::GemmShape<64, 64, 32>, hytlass::gemm::GemmShape<16, 16, 8>,
hytlass::epilogue::thread::LinearCombination<float, 1, float, float>,
hytlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 1, 4,
4, hytlass::arch::OpMultiplyAdd
>;
Gemm gemm_op;
hytlass::Status status = gemm_op({
{m, n, k},
{A, lda},
batch_stride_A,
{B, ldb},
batch_stride_B,
{C, ldc},
batch_stride_C,
{C, ldc},
batch_stride_C,
{alpha, beta},
batch_count
});
if (status != hytlass::Status::kSuccess) {
return hipErrorUnknown;
}
return hipSuccess;
}
template<typename T>
hipError_t strided_batched_gemm_nn_reference(
int m,
int n,
int k,
T alpha,
std::vector<T> const &A,
int lda,
long long int batch_stride_A,
std::vector<T> const &B,
int ldb,
long long int batch_stride_B,
std::vector<T> &C,
int ldc,
long long int batch_stride_C,
T beta,
int batch_count) {
/*
strided batched gemm NN
*/
hipError_t result = hipSuccess;
if (A.size() < size_t(lda * k * batch_count)) {
std::cout << "the size of A is too small" << std::endl;
return hipErrorInvalidValue;
}
if (B.size() < size_t(ldb * n)) {
std::cout << "the size of B is too small" << std::endl;
return hipErrorInvalidValue;
}
if (C.size() < size_t(ldc * n * batch_count)) {
std::cout << "the size of C is too small" << std::endl;
return hipErrorInvalidValue;
}
for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
for (int n_idx = 0; n_idx < n; n_idx++) {
for (int m_idx = 0; m_idx < m; m_idx++) {
T accum = beta * C[batch_idx * batch_stride_C + n_idx * ldc + m_idx];
for (int k_idx = 0; k_idx < k; k_idx++) {
accum += alpha
* A[batch_idx * batch_stride_A + k_idx * lda + m_idx]
* B[batch_idx * batch_stride_B + n_idx * ldb + k_idx];
}
C[batch_idx * batch_stride_C + n_idx * ldc + m_idx] = accum;
}
}
}
return result;
}
hipError_t run_batched_gemm(Options &options) {
std::cout << "Running strided batched gemm" << std::endl;
// Arbitrary problem size
int m = options.problem_size.m();
int n = options.problem_size.n();
int k = options.problem_size.k();
int batch_count = options.batch_count;
// alpha and beta
float alpha = options.alpha;
float beta = options.beta;
// A, B are non-transpose, column major
int const lda = m;
int const ldb = k * batch_count;
int const ldc = m;
int const count_A = batch_count * lda * k;
int const count_B = ldb * n;
int const count_C = batch_count * ldc * n;
// the memory is batched along K dimension
long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
long long int batch_stride_B = static_cast<long long int>(k);
long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);
hipError_t result = hipSuccess;
// allocate the host memory
std::vector<float> host_A(count_A);
std::vector<float> host_B(count_B);
std::vector<float> host_C(count_C);
std::vector<float> result_C(count_C);
// allocate the device memory
float *A;
float *B;
float *C;
result = hipMalloc(&A, count_A * sizeof(float));
if (result != hipSuccess) {
std::cerr << "hipMalloc result = " << result << std::endl;
return result;
}
result = hipMalloc(&B, count_B * sizeof(float));
if (result != hipSuccess) {
std::cerr << "hipMalloc result = " << result << std::endl;
return result;
}
result = hipMalloc(&C, count_C * sizeof(float));
if (result != hipSuccess) {
std::cerr << "hipMalloc result = " << result << std::endl;
return result;
}
// Limit range to avoid floating-point errors
int const kRange = 8;
// fill A
for (int b_idx = 0; b_idx < batch_count; b_idx++) {
for (int col_idx = 0; col_idx < k; col_idx++) {
for (int row_idx = 0; row_idx < m; row_idx++) {
host_A[row_idx + col_idx * lda + b_idx * lda * k] = static_cast<float>((row_idx + col_idx * lda + b_idx * lda * k) % kRange);
}
}
}
// fill B
for (int b_idx = 0; b_idx < batch_count; b_idx++) {
for (int col_idx = 0; col_idx < n; col_idx++) {
for (int row_idx = 0; row_idx < k; row_idx++) {
host_B[row_idx + col_idx * ldb + b_idx * k] = static_cast<float>(((n + k * ldb + batch_count * k) - (row_idx + col_idx * ldb + b_idx * k)) % kRange);
}
}
}
// fill C
for (int b_idx = 0; b_idx < batch_count; b_idx++) {
for (int col_idx = 0; col_idx < n; col_idx++) {
for (int row_idx = 0; row_idx < m; row_idx++) {
host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f;
}
}
}
// ref memory
std::vector<float> ref_A(host_A);
std::vector<float> ref_B(host_B);
std::vector<float> ref_C(host_C);
// copy host memory to device
result = hipMemcpy(A, host_A.data(), count_A * sizeof(float), hipMemcpyHostToDevice);
if (result != hipSuccess) {
std::cerr << "hipMemcpy result = " << result << std::endl;
return result;
}
result = hipMemcpy(B, host_B.data(), count_B * sizeof(float), hipMemcpyHostToDevice);
if (result != hipSuccess) {
std::cerr << "hipMemcpy result = " << result << std::endl;
return result;
}
result = hipMemcpy(C, host_C.data(), count_C * sizeof(float), hipMemcpyHostToDevice);
if (result != hipSuccess) {
std::cerr << "hipMemcpy result = " << result << std::endl;
return result;
}
result = hytlass_strided_batched_sgemm(
m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
beta, batch_count);
if (result != hipSuccess)
return result;
// copy device memory to host
result = hipMemcpy(result_C.data(), C, count_C * sizeof(float), hipMemcpyDeviceToHost);
if (result != hipSuccess) {
std::cerr << "hipMemcpy result = " << result << std::endl;
return result;
}
//compare with reference code
result = strided_batched_gemm_nn_reference(m, n, k, alpha, ref_A, lda, batch_stride_A, ref_B, ldb, batch_stride_B, ref_C, ldc, batch_stride_C,
beta, batch_count);
if (result != hipSuccess)
return result;
// Expect bit-level accuracy for this simple example
if (ref_C != result_C) {
std::cout << "HYTLASS strided batched gemm does not run correctly" << std::endl;
return hipErrorUnknown;
}
// free memory
result = hipFree(A);
if (result != hipSuccess) {
std::cerr << "hipFree result = " << result << std::endl;
return result;
}
result = hipFree(B);
if (result != hipSuccess) {
std::cerr << "hipFree result = " << result << std::endl;
return result;
}
result = hipFree(C);
if (result != hipSuccess) {
std::cerr << "hipFree result = " << result << std::endl;
return result;
}
return result;
}
int main(int argc, const char **argv) {
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
printf("%d x %d x %d x %d tensor op Matrix Multiply\n", \
options.problem_size.m(), options.problem_size.n(), options.problem_size.k(), options.batch_count);
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
hipError_t result = hipSuccess;
result = run_batched_gemm(options);
if (result == hipSuccess) {
std::cout << "Passed." << std::endl;
}
// Exit.
return result == hipSuccess ? 0 : -1;
}
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gfx928_group_gemm
gfx928_group_gemm.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM Grouped Example.
  This workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices
  in global memory are passed to the kernel in arrays (also held in global memory). Similarly,
  leading dimensions and problem sizes are stored in arrays in GMEM.
This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM
concept may be distinct.
This benchmark program initializes a workspace with random problem sizes for a given number of
groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to
model problems more similar to the traditional batched GEMM.
Additionally, problem sizes are collected and binned to compute the same problem as a series of
conventional batched GEMMs (setup for this problem is not timed). This demonstrates the performance
enhancement achieved by implementing a specialized grouped GEMM kernel.
Examples:
    # Runs a grouped GEMM with 100 random problem sizes
    $ ./gfx928_group_gemm --groups=100
# Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024)
$ ./gfx928_group_gemm --groups=100 --k=1024 --verbose=true
# Runs a grouped GEMM that is equivalent to a batched GEMM
$ ./gfx928_group_gemm --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true
*/
/////////////////////////////////////////////////////////////////////////////////////////////////
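//
// Conceptually (an illustrative sketch, not the kernel's literal interface), a grouped
// launch consumes per-group metadata resident in global memory:
//
//   GemmCoord problem_sizes[group_count];  // one MxNxK extent per group
//   ElementA *ptr_A[group_count];          // one base pointer per matrix per group
//   int64_t   lda[group_count];            // one leading dimension per group
//
// so a single kernel launch can sweep the tiles of every group without host round trips.
//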
#include <chrono>
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <map>
#include <unordered_map>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/gemm.h"
#include "hytlass/gemm/kernel/gemm_grouped.h"
#include "hytlass/gemm/kernel/default_gemm_grouped.h"
#include "hytlass/gemm/device/gemm_grouped.h"
#include "hytlass/gemm/device/gemm_universal.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/distribution.h"
#include "hytlass/util/device_memory.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/reference/host/gemm_complex.h"
#include "hytlass/util/reference/device/gemm_complex.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/device/tensor_fill.h"
#include "hytlass/util/reference/host/tensor_norm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double initialization_time_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double initialization_time_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess
):
runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops),
status(status), error(error), passed(true) { }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Hash function for hytlass::gemm::GemmCoord
struct HashGemmCoord {
size_t operator()(hytlass::gemm::GemmCoord const &problem) const {
std::hash<int> hasher;
return (hasher(problem.m() * 3)) ^ (hasher(1 + problem.n() * 5)) ^ (hasher(2 + problem.k() * 7));
}
};
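// A minimal, hypothetical smoke test (not called by this example): equal problem
// shapes must hash equally so they share a bucket in the unordered_map bins below.
inline bool hash_gemm_coord_smoke_test() {
  HashGemmCoord hasher;
  return hasher(hytlass::gemm::GemmCoord(128, 256, 64)) ==
         hasher(hytlass::gemm::GemmCoord(128, 256, 64));
}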
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
bool error;
bool reference_check;
bool profile_initialization;
bool sort_problems;
std::vector<hytlass::gemm::GemmCoord> problem_sizes;
// problem size bins
std::unordered_map<
hytlass::gemm::GemmCoord,
std::vector<int32_t>,
HashGemmCoord> problem_bins;
int alignment;
int problem_count;
int iterations;
int hip_streams;
bool verbose;
float alpha;
float beta;
std::string benchmark_path;
std::string output_tag;
std::ofstream output_file;
using GroupScheduleMode = hytlass::gemm::kernel::GroupScheduleMode;
std::vector<GroupScheduleMode> scheduler_modes;
std::unordered_map<std::string, GroupScheduleMode>
str_to_scheduler_mode = {
{"kDeviceOnly", GroupScheduleMode::kDeviceOnly},
{"kHostPrecompute", GroupScheduleMode::kHostPrecompute}
};
struct GroupScheduleModeHash {
size_t operator()(GroupScheduleMode m) const {
return static_cast<size_t>(m);
}
};
std::unordered_map<GroupScheduleMode, std::string, GroupScheduleModeHash>
scheduler_mode_to_str = {
{GroupScheduleMode::kDeviceOnly, "kDeviceOnly"},
{GroupScheduleMode::kHostPrecompute, "kHostPrecompute"}
};
std::vector<GroupScheduleMode> all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute};
//
// Methods
//
  // Initializers listed in member declaration order
  Options():
    help(false),
    error(false),
    reference_check(true),
    profile_initialization(false),
    sort_problems(false),
    alignment(8),
    problem_count(15),
    iterations(20),
    hip_streams(0),
    verbose(false),
    alpha(1),
    beta(),
    scheduler_modes({GroupScheduleMode::kDeviceOnly})
  { }
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("alignment", alignment, 8);
cmd.get_cmd_line_argument("groups", problem_count, 15);
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
cmd.get_cmd_line_argument("beta", beta, 0.0f);
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("streams", hip_streams, 0);
cmd.get_cmd_line_argument("verbose", verbose, false);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false);
cmd.get_cmd_line_argument("sort-problems", sort_problems, false);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
std::vector<std::string> scheduler_mode_strs;
cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs);
if (!scheduler_mode_strs.empty()) {
scheduler_modes.clear();
if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") {
scheduler_modes = all_scheduler_modes;
} else {
for (std::string precomp_str : scheduler_mode_strs) {
auto it = str_to_scheduler_mode.find(precomp_str);
if (it != str_to_scheduler_mode.end()) {
scheduler_modes.push_back(it->second);
} else if (precomp_str == "all") {
std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl;
error = true;
return;
} else {
std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl;
error = true;
return;
}
}
}
}
std::string output_path;
cmd.get_cmd_line_argument("tag", output_tag);
cmd.get_cmd_line_argument("output_file", output_path);
if (!output_path.empty()) {
std::ios_base::openmode open_mode = std::ios_base::out;
std::ifstream input_file(output_path.c_str());
if (input_file.good()) {
open_mode = std::ios_base::app;
input_file.close();
}
output_file.open(output_path.c_str(), open_mode);
if (output_file.good() && open_mode != std::ios_base::app) {
output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n";
}
}
// Decide how to initialize the problems
if (!benchmark_path.empty()) {
if (!benchmark_problems()) {
error = true;
problem_sizes.clear();
return;
}
}
else {
randomize_problems(cmd);
}
// Post-process the problem sizes
bin_problems();
}
void randomize_problems(hytlass::CommandLine &cmd) {
//
// For now, randomly choose the problem sizes.
//
int cmd_line_m = -1;
int cmd_line_n = -1;
int cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes.reserve(problem_count);
for (int i = 0; i < problem_count; ++i) {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = alignment * ((rand() % 256) + 1);
}
if (n < 1) {
n = alignment * ((rand() % 256) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 256) + 1);
}
hytlass::gemm::GemmCoord problem(m, n, k);
problem_sizes.push_back(problem);
}
}
/// Load a benchmark
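  /// Expected file format (mirrors the usage text below): one group per line,
  /// "<index> <M>x<N>x<K>", e.g.
  ///   0 1024x256x520
  ///   1 520x264x1024
  /// Dimensions are rounded up to the configured alignment.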
bool benchmark_problems() {
std::ifstream file(benchmark_path);
if (!file.good()) {
return false;
}
while (file.good()) {
int idx = -1;
std::string extent_str;
file >> idx >> extent_str;
if (idx < 0 || extent_str.empty()) {
break;
}
hytlass::gemm::GemmCoord extent;
std::vector<std::string> tokens;
hytlass::CommandLine::tokenize(tokens, extent_str, 'x');
for (int i = 0; i < int(tokens.size()); ++i) {
int x = std::atoi(tokens.at(i).c_str());
// round up
if (x % alignment) {
x += (alignment - (x % alignment));
}
extent.at(i) = x;
}
if (extent.product()) {
problem_sizes.push_back(extent);
}
}
return true;
}
/// Post processes the problems
void bin_problems() {
problem_bins.clear();
problem_count = int(problem_sizes.size());
//
// Insert the problem sizes into a sorted container class. This is *NOT* necessary
// to run the HYTLASS kernel, but it enables the execution of hipblas's batched GEMM.
//
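    // Illustrative example: given problem_sizes {512x512x64, 256x256x64, 512x512x64},
    // binning yields {512x512x64 -> [0, 2]} and {256x256x64 -> [1]}, so the two
    // identical problems can later be issued as one batched GEMM with batch count 2.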
for (int i = 0; i < int(problem_sizes.size()); ++i) {
auto it = problem_bins.find(problem_sizes.at(i));
if (it == problem_bins.end()) {
problem_bins.insert({problem_sizes.at(i), std::vector<int32_t>({i}) });
}
else {
it->second.push_back(i);
}
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "05_hytlass_group_gemm\n\n"
<< " This example profiles the performance of a 'grouped' GEMM kernel. This is similar to batched GEMM\n"
<< " in that multiple, independent GEMMs are computed by one grid launch. It differs in that each\n"
<< " 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored\n"
<< " in device Global Memory and loaded by the kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --benchmark=<str> Executes a benchmark problem size.\n"
<< " --output_file=<str> Path to a CSV file to output results. If it exists already, results are appended.\n"
<< " --tag=<str> String tag to prepend to the CSV file.\n"
<< " --groups=<int> Number of individual GEMM problems (default: --groups=15)\n"
<< " --m=<int> Sets the M dimension for all groups. Otherwise, it is selected randomly\n"
<< " --n=<int> Sets the N dimension for all groups. Otherwise, it is selected randomly\n"
<< " --k=<int> Sets the K dimension for all groups. Otherwise, it is selected randomly\n"
<< " --alpha=<f32> Epilogue scalar alpha (real part)\n"
<< " --beta=<f32> Epilogue scalar beta (real part)\n"
<< " --scheduler-modes=<str> List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --reference-check=<bool> If true, performs reference check.\n"
<< " --verbose=<bool> If true, prints problem sizes and batching structure.\n"
<< " --profile-initialization=<bool> If true, profiles the device-level kernel's initialization.\n"
<< " --sort-problems=<bool> If true, sorts problem sizes in descending order of GEMM-K dimension.\n";
out << "\n\nExamples:\n\n"
<< "# Runs a grouped GEMM with 100 random problem sizes\n"
<< "$ ./examples/05_hytlass_group_gemm/gfx928_group_gemm --groups=100\n\n"
<< "# Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024)\n"
<< "$ ./examples/05_hytlass_group_gemm/gfx928_group_gemm --groups=100 --k=1024 --verbose=true\n\n"
<< "# Runs a grouped GEMM that is equivalent to a batched GEMM\n"
<< "$ ./examples/05_hytlass_group_gemm/gfx928_group_gemm --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true\n\n"
<< "# Runs a grouped GEMM with each different scheduler mode\n"
<< "$ ./examples/05_hytlass_group_gemm/gfx928_group_gemm --scheduler-modes=all\n\n"
<< "# Runs a grouped GEMM with each different scheduler mode and profiles host-side initialization time\n"
<< "$ ./examples/05_hytlass_group_gemm/gfx928_group_gemm --scheduler-modes=all --profile-initialization=true\n\n"
<< "# Runs a grouped GEMM problem given an externally supplied benchmark file. This is a text file in which\n"
<< "# Each line contains a unique group index and an MxNxK triple indicating problemsize.\n"
<< "#\n"
<< "# For example, assume the following are the contents of 'problems.txt'\n"
<< "#\n"
<< "# 0 1024x256x520\n"
<< "# 1 520x264x1024\n"
<< "# 2 96x48x1024\n"
<< "#\n"
<< "$ ./examples/05_hytlass_group_gemm/gfx928_group_gemm --benchmark=problems.txt\n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = int64_t();
for (auto const & problem : problem_sizes) {
fmas += problem.product();
}
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
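  // Worked example (illustrative): a single 1024x1024x1024 group contributes
  // 1024^3 ~= 1.074e9 multiply-adds; at runtime_s = 1e-3 this reports
  // 2 * 1.074e9 / 1e9 / 1e-3 ~= 2147 GFLOP/s.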
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
class BaseTestbed {
public:
//
// Type definitions
//
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementC = typename Gemm::ElementC;
using ElementAccumulator = typename Gemm::ElementAccumulator;
using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
using ElementCompute = typename EpilogueOutputOp::ElementCompute;
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using MatrixCoord = typename LayoutC::TensorCoord;
//
// Data members
//
Options & options;
/// Initialization
hytlass::Distribution::Kind init_A;
hytlass::Distribution::Kind init_B;
hytlass::Distribution::Kind init_C;
uint32_t seed;
hytlass::DeviceAllocation<hytlass::gemm::GemmCoord> problem_sizes_device;
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
std::vector<int64_t> lda_host;
std::vector<int64_t> ldb_host;
std::vector<int64_t> ldc_host;
std::vector<int64_t> ldd_host;
hytlass::DeviceAllocation<int64_t> lda;
hytlass::DeviceAllocation<int64_t> ldb;
hytlass::DeviceAllocation<int64_t> ldc;
hytlass::DeviceAllocation<int64_t> ldd;
hytlass::DeviceAllocation<ElementA> block_A;
hytlass::DeviceAllocation<ElementB> block_B;
hytlass::DeviceAllocation<ElementC> block_C;
hytlass::DeviceAllocation<ElementC> block_D;
hytlass::DeviceAllocation<ElementA *> ptr_A;
hytlass::DeviceAllocation<ElementB *> ptr_B;
hytlass::DeviceAllocation<ElementC *> ptr_C;
hytlass::DeviceAllocation<ElementC *> ptr_D;
BaseTestbed(
Options &options_,
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
uint32_t seed_ = 3080
):
options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
int problem_count() const {
return options.problem_count;
}
/// Helper to initialize a tensor view
template <typename Element>
void initialize_tensor(
Element *ptr,
size_t capacity,
hytlass::Distribution::Kind dist_kind,
uint32_t seed) {
if (dist_kind == hytlass::Distribution::Uniform) {
Element scope_max, scope_min;
int bits_input = hytlass::sizeof_bits<Element>::value;
int bits_output = hytlass::sizeof_bits<typename Gemm::ElementC>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
if (hytlass::sizeof_bits<ElementAccumulator>::value <= 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
} else {
scope_max = 8;
scope_min = -8;
}
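      // Note: the ranges above narrow as input/accumulator precision drops, an intent
      // (assumed here) to keep accumulated dot products exactly representable so the
      // reference comparison in verify() remains meaningful.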
hytlass::reference::device::BlockFillRandomUniform(
ptr, capacity, seed, scope_max, scope_min, 0);
}
else if (dist_kind == hytlass::Distribution::Gaussian) {
hytlass::reference::device::BlockFillRandomGaussian(
ptr, capacity, seed, Element(), Element(0.5f));
}
else if (dist_kind == hytlass::Distribution::Sequential) {
// Fill with increasing elements
hytlass::reference::device::BlockFillSequential(
ptr, capacity, Element(1), Element());
}
else {
// Fill with all 1s
hytlass::reference::device::BlockFillSequential(
ptr, capacity, Element(), Element(1));
}
}
/// Allocates device-side data
void allocate() {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
lda_host.resize(problem_count());
ldb_host.resize(problem_count());
ldc_host.resize(problem_count());
ldd_host.resize(problem_count());
for (int32_t i = 0; i < problem_count(); ++i) {
auto problem = options.problem_sizes.at(i);
lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0);
ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0);
ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0);
ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0);
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
int64_t elements_A = problem.m() * problem.k();
int64_t elements_B = problem.k() * problem.n();
int64_t elements_C = problem.m() * problem.n();
int64_t elements_D = problem.m() * problem.n();
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
}
lda.reset(problem_count());
ldb.reset(problem_count());
ldc.reset(problem_count());
ldd.reset(problem_count());
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
}
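  // Illustrative result of allocate() (hypothetical extents): with two groups whose A
  // matrices are 2x3 and 4x5, block_A holds 6 + 20 elements contiguously and
  // offset_A = {0, 6}; group i's matrix is then addressed as block_A.get() + offset_A[i].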
/// Initializes device-side data
void initialize() {
problem_sizes_device.reset(problem_count());
problem_sizes_device.copy_from_host(options.problem_sizes.data());
lda.copy_from_host(lda_host.data());
ldb.copy_from_host(ldb_host.data());
ldc.copy_from_host(ldc_host.data());
ldd.copy_from_host(ldd_host.data());
//
// Assign pointers
//
std::vector<ElementA *> ptr_A_host(problem_count());
std::vector<ElementB *> ptr_B_host(problem_count());
std::vector<ElementC *> ptr_C_host(problem_count());
std::vector<ElementC *> ptr_D_host(problem_count());
for (int32_t i = 0; i < problem_count(); ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
ptr_A.reset(problem_count());
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(problem_count());
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(problem_count());
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(problem_count());
ptr_D.copy_from_host(ptr_D_host.data());
//
// Initialize the problems of the workspace
//
initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021);
initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022);
initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023);
hytlass::reference::device::BlockFillSequential(
block_D.get(), block_D.size(), ElementC(), ElementC());
}
  /// Verifies the GEMM results against a device-side reference
bool verify() {
bool passed = true;
for (int32_t i = 0; i < problem_count(); ++i) {
hytlass::gemm::GemmCoord problem = options.problem_sizes.at(i);
LayoutA layout_A(lda_host.at(i));
LayoutB layout_B(ldb_host.at(i));
LayoutC layout_C(ldc_host.at(i));
LayoutC layout_D(ldd_host.at(i));
MatrixCoord extent_A{problem.m(), problem.k()};
MatrixCoord extent_B{problem.k(), problem.n()};
MatrixCoord extent_C{problem.m(), problem.n()};
hytlass::TensorView<ElementA, LayoutA> view_A(block_A.get() + offset_A.at(i), layout_A, extent_A);
hytlass::TensorView<ElementB, LayoutB> view_B(block_B.get() + offset_B.at(i), layout_B, extent_B);
hytlass::TensorView<ElementC, LayoutC> view_C(block_C.get() + offset_C.at(i), layout_C, extent_C);
hytlass::DeviceAllocation<ElementC> block_Ref(layout_D.capacity(extent_C));
hytlass::TensorView<ElementC, LayoutC> view_Ref_device(block_Ref.get(), layout_D, extent_C);
// Reference GEMM
hytlass::reference::device::GemmComplex<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute, ElementAccumulator
>(
problem,
options.alpha,
view_A,
Gemm::kTransformA,
view_B,
Gemm::kTransformB,
options.beta,
view_C,
view_Ref_device,
ElementAccumulator(0)
);
// Copy to host memory
std::vector<ElementC> matrix_D(layout_D.capacity(extent_C));
std::vector<ElementC> matrix_Ref(layout_D.capacity(extent_C));
hytlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size());
hytlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size());
hytlass::TensorView<ElementC, LayoutC> view_D( matrix_D.data(), layout_D, extent_C);
hytlass::TensorView<ElementC, LayoutC> view_Ref(matrix_Ref.data(), layout_D, extent_C);
ElementC eps(1e-3);
ElementC non_zero_floor(1e-6);
      // Print the first mismatch (if any) to aid debugging
      for (int _i = 0; _i < problem.m() * problem.n(); _i++) {
        float _diff = std::abs(float(matrix_D.data()[_i] - matrix_Ref.data()[_i]));
        if (_diff > (float)eps) {
          printf("first mismatch: diff = %f at index %d (D = %f, ref = %f)\n",
            _diff, _i, (float)matrix_D.data()[_i], (float)matrix_Ref.data()[_i]);
          break;
        }
      }
// Reference check
passed = hytlass::reference::host::TensorRelativelyEquals(view_D, view_Ref, eps, non_zero_floor);
if (!passed) {
std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl;
return passed;
}
}
return passed;
}
};
template <typename Gemm>
class TestbedBatched : BaseTestbed<Gemm> {
public:
TestbedBatched(
Options &options_,
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
uint32_t seed_ = 3080
): BaseTestbed<Gemm>(options_, init_A_, init_B_, init_C_, seed_) {}
void print_problem_sizes() {
std::cout << std::endl;
size_t bin_idx = 0;
size_t problem_count_check = 0;
std::cout << "Conventionally executed as " << this->options.problem_bins.size() << " batched GEMMs:\n";
for (auto const & bin : this->options.problem_bins) {
std::cout << " [" << bin_idx << "]: "
<< bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k()
<< ", batch count: " << bin.second.size() << "\n";
++bin_idx;
problem_count_check += bin.second.size();
}
if (problem_count_check != size_t(this->problem_count())) {
std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl;
}
std::cout << std::endl;
}
/// Executes a batched kernel and measures runtime
Result profile() {
std::cout << "Batched GEMM:\n"
<< "====================================================" << std::endl;
Result result;
result.passed = false;
// Initialize the problem
this->allocate();
this->initialize();
if (this->options.verbose) {
print_problem_sizes();
}
//
// Prepare batched GEMM environment
//
int32_t effective_streams = (this->options.hip_streams ? this->options.hip_streams : 1);
// Array of leading dimensions used by batched GEMM calls
std::vector<hytlass::gemm::GemmCoord> bin_problem_sizes;
std::vector<int32_t> bin_count;
std::vector<int32_t> bin_ldm_A;
std::vector<int32_t> bin_ldm_B;
std::vector<int32_t> bin_ldm_C;
std::vector<int32_t> bin_start;
std::vector<void const *> ptr_A_batched_host;
std::vector<void const *> ptr_B_batched_host;
std::vector<void *> ptr_C_batched_host;
for (auto const & bin : this->options.problem_bins) {
int first_idx = bin.second.front();
bin_problem_sizes.push_back(this->options.problem_sizes.at(first_idx));
bin_count.push_back(int32_t(bin.second.size()));
bin_ldm_A.push_back(static_cast<int32_t>(this->lda_host.at(first_idx)));
bin_ldm_B.push_back(static_cast<int32_t>(this->ldb_host.at(first_idx)));
bin_ldm_C.push_back(static_cast<int32_t>(this->ldc_host.at(first_idx)));
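      // Pad the pointer arrays to an even index so each bin's pointer sub-array
      // begins on a 16-byte boundary (two 8-byte pointers); this is an assumption
      // about the batched-array kernel's alignment expectations.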
if (ptr_A_batched_host.size() % 2) {
ptr_A_batched_host.push_back(nullptr);
ptr_B_batched_host.push_back(nullptr);
ptr_C_batched_host.push_back(nullptr);
}
bin_start.push_back(int32_t(ptr_A_batched_host.size()));
for (int idx : bin.second) {
if (bin_problem_sizes.back() != this->options.problem_sizes.at(idx)) {
std::cerr << "Error - failed to group problems.\n";
return result;
}
if (bin_ldm_A.back() != this->lda_host.at(idx)) {
std::cerr << "Error - failed to group problems.\n";
return result;
}
if (bin_ldm_B.back() != this->ldb_host.at(idx)) {
std::cerr << "Error - failed to group problems.\n";
return result;
}
if (bin_ldm_C.back() != this->ldc_host.at(idx)) {
std::cerr << "Error - failed to group problems.\n";
return result;
}
ptr_A_batched_host.push_back(this->block_A.get() + this->offset_A.at(idx));
ptr_B_batched_host.push_back(this->block_B.get() + this->offset_B.at(idx));
ptr_C_batched_host.push_back(this->block_D.get() + this->offset_C.at(idx));
}
}
// Array of GMEM pointers used by batched array GEMM calls
hytlass::DeviceAllocation<void const *> ptr_A_batched;
hytlass::DeviceAllocation<void const *> ptr_B_batched;
hytlass::DeviceAllocation<void *> ptr_C_batched;
ptr_A_batched.reset(ptr_A_batched_host.size());
ptr_B_batched.reset(ptr_A_batched_host.size());
ptr_C_batched.reset(ptr_A_batched_host.size());
ptr_A_batched.copy_from_host(ptr_A_batched_host.data());
ptr_B_batched.copy_from_host(ptr_B_batched_host.data());
ptr_C_batched.copy_from_host(ptr_C_batched_host.data());
//
// Create hip streams to maximize concurrency of batched-array GEMM kernels
//
std::vector<hipStream_t> hip_streams;
//
// Warmup run
//
if (this->options.hip_streams) {
for (int i = 0; i < this->options.hip_streams; ++i) {
hipStream_t stream;
result.error = hipStreamCreate(&stream);
if (result.error != hipSuccess) {
std::cerr << "Failed to create hip stream." << std::endl;
return result;
}
hip_streams.push_back(stream);
}
}
else {
hip_streams.push_back(nullptr);
}
// Use 'D' for the in/out workspace
this->block_D.copy_from_device(this->block_C.get());
for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) {
hytlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx];
int32_t batch_count = bin_count[bin_idx];
int32_t bin_start_idx = bin_start[bin_idx];
int32_t lda = bin_ldm_A[bin_idx];
int32_t ldb = bin_ldm_B[bin_idx];
int32_t ldc = bin_ldm_C[bin_idx];
      void const ** ptr_A_array = ptr_A_batched.get() + bin_start_idx;
      void const ** ptr_B_array = ptr_B_batched.get() + bin_start_idx;
      void ** ptr_C_array = ptr_C_batched.get() + bin_start_idx;
//
// Initialize the HYTLASS GEMM operator
//
// Configure the GEMM arguments
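      // In kArray mode the kernel loads each batch's A/B/C/D pointer from the
      // device-side pointer arrays, so the four batch strides below are passed as zero.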
typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta);
typename Gemm::Arguments arguments{
hytlass::gemm::GemmUniversalMode::kArray,
problem,
batch_count,
epilogue_op,
(void const *)ptr_A_array,
(void const *)ptr_B_array,
(void const *)ptr_C_array,
(void *)ptr_C_array,
int64_t(),
int64_t(),
int64_t(),
int64_t(),
int64_t(lda),
int64_t(ldb),
int64_t(ldc),
int64_t(ldc)
};
Gemm gemm_op;
hytlass::Status status = gemm_op.initialize(arguments);
if (status != hytlass::Status::kSuccess) {
std::cerr << "HYTLASS error on line " << __LINE__ << std::endl;
return result;
}
status = gemm_op();
if (status != hytlass::Status::kSuccess) {
std::cerr << "HYTLASS error on line " << __LINE__ << std::endl;
return result;
}
}
//
// Wait for completion
//
result.error = hipDeviceSynchronize();
if (result.error != hipSuccess) {
std::cerr << "Kernel execution error: " << hipGetErrorString(result.error);
return result;
}
//
// Construct events
//
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
        return result;
}
}
// Record an event at the start of a series of GEMM operations
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
//
// Run profiling loop
//
int last_stream_idx = 0;
for (int iter = 0; iter < this->options.iterations; ++iter) {
for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) {
hytlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx];
int32_t batch_count = bin_count[bin_idx];
int32_t bin_start_idx = bin_start[bin_idx];
int32_t lda = bin_ldm_A[bin_idx];
int32_t ldb = bin_ldm_B[bin_idx];
int32_t ldc = bin_ldm_C[bin_idx];
        void const ** ptr_A_array = ptr_A_batched.get() + bin_start_idx;
        void const ** ptr_B_array = ptr_B_batched.get() + bin_start_idx;
        void ** ptr_C_array = ptr_C_batched.get() + bin_start_idx;
last_stream_idx = (bin_idx % effective_streams);
//
// Initialize the HYTLASS GEMM operator
//
// Configure the GEMM arguments
typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta);
typename Gemm::Arguments arguments{
hytlass::gemm::GemmUniversalMode::kArray,
problem,
batch_count,
epilogue_op,
(void const *)ptr_A_array,
(void const *)ptr_B_array,
(void const *)ptr_C_array,
(void *)ptr_C_array,
int64_t(),
int64_t(),
int64_t(),
int64_t(),
int64_t(lda),
int64_t(ldb),
int64_t(ldc),
int64_t(ldc)
};
Gemm gemm_op;
hytlass::Status status = gemm_op.initialize(arguments);
if (status != hytlass::Status::kSuccess) {
std::cerr << "HYTLASS error on line " << __LINE__ << std::endl;
return result;
}
status = gemm_op(hip_streams[last_stream_idx]);
if (status != hytlass::Status::kSuccess) {
std::cerr << "HYTLASS error on line " << __LINE__ << std::endl;
return result;
}
}
}
//
// Stop profiling loop
//
// Record an event when the GEMM operations have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
//
// Wait for work to be completed
//
result.error = hipDeviceSynchronize();
if (result.error != hipSuccess) {
std::cerr << "Kernel execution error: " << hipGetErrorString(result.error);
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(this->options.iterations);
result.gflops = this->options.gflops(result.runtime_ms / 1000.0);
//
// Cleanup
//
for (auto event : events) {
(void)hipEventDestroy(event);
}
for (auto stream : hip_streams) {
if (stream) {
(void)hipStreamDestroy(stream);
}
}
std::cout << " " << this->options.problem_bins.size() << " batched GEMMs launched" << std::endl;
std::cout << std::endl;
std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " " << "Batched GFLOPs: " << result.gflops << std::endl;
std::string provider = "HYTLASS";
if (this->options.output_file.good()) {
this->options.output_file << this->options.output_tag << "," << provider << ",batched,"
<< this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl;
}
result.passed = true;
return result;
}
};
template <typename Gemm_, hytlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_>
class TestbedGrouped : BaseTestbed<Gemm_> {
public:
TestbedGrouped(
Options &options_,
hytlass::Distribution::Kind init_A_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_B_ = hytlass::Distribution::Uniform,
hytlass::Distribution::Kind init_C_ = hytlass::Distribution::Uniform,
uint32_t seed_ = 3080
): BaseTestbed<Gemm_>(options_, init_A_, init_B_, init_C_, seed_) {}
// Redefine GEMM with different GroupScheduleMode_
using GemmKernel = typename hytlass::gemm::kernel::DefaultGemmGrouped<
typename Gemm_::ElementA,
typename Gemm_::LayoutA,
Gemm_::kTransformA,
Gemm_::kAlignmentA,
typename Gemm_::ElementB,
typename Gemm_::LayoutB,
Gemm_::kTransformB,
Gemm_::kAlignmentB,
typename Gemm_::ElementC,
typename Gemm_::LayoutC,
typename Gemm_::ElementAccumulator,
typename Gemm_::OperatorClass,
typename Gemm_::ArchTag,
typename Gemm_::ThreadblockShape,
typename Gemm_::WarpShape,
typename Gemm_::InstructionShape,
typename Gemm_::EpilogueOutputOp,
typename Gemm_::ThreadblockSwizzle,
Gemm_::kStages,
GroupScheduleMode_>::GemmKernel;
using Gemm = hytlass::gemm::device::GemmGrouped<GemmKernel>;
/// Verbose printing of problem sizes
void print_problem_sizes() {
std::cout << std::endl;
// Print groups
std::cout << this->problem_count() << " groups:\n";
int32_t idx = 0;
int64_t total_tiles = 0;
for (auto const & problem : this->options.problem_sizes) {
int tiles = Gemm::problem_tile_count(problem);
total_tiles += tiles;
std::cout << " [" << idx << "]: "
<< problem.m() << "-by-" << problem.n() << "-by-" << problem.k()
<< " (" << tiles << " threadblock tiles)" << "\n";
++idx;
}
std::cout << std::endl;
}
/// Sort problems in descending order of problem-K dimension
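  /// A common motivation (an assumption about scheduling, not a measured guarantee):
  /// issuing long-K tiles first tends to balance load across threadblocks.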
void sort_problems() {
Gemm::sort_problems(this->options.problem_count,
this->options.problem_sizes.data(),
this->lda_host.data(),
this->ldb_host.data(),
this->ldc_host.data(),
this->ldd_host.data(),
this->offset_A.data(),
this->offset_B.data(),
this->offset_C.data(),
this->offset_D.data());
}
/// Executes a grouped kernel and measures runtime
Result profile() {
std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second;
std::cout << std::endl;
std::cout << "Grouped GEMM (HYTLASS) with mode " << sched_mode << ":\n"
<< "====================================================" << std::endl;
Result result;
int threadblock_count = Gemm::sufficient(this->options.problem_sizes.data(), this->options.problem_count);
// Early exit
if (!threadblock_count) {
std::cout << "Active hip device lacks hardware resources to run HYTLASS Grouped GEMM kernel." << std::endl;
return result;
}
result.passed = false;
// Initialize the problem
this->allocate();
if (this->options.sort_problems) {
sort_problems();
}
this->initialize();
if (this->options.verbose) {
print_problem_sizes();
}
// Configure the GEMM arguments
typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta);
// Configure GEMM arguments
typename Gemm::Arguments args(
this->problem_sizes_device.get(),
this->problem_count(),
threadblock_count,
epilogue_op,
this->ptr_A.get(),
this->ptr_B.get(),
this->ptr_C.get(),
this->ptr_D.get(),
this->lda.get(),
this->ldb.get(),
this->ldc.get(),
this->ldd.get(),
this->options.problem_sizes.data()
);
// Initialize the GEMM object
Gemm gemm;
size_t workspace_size = gemm.get_workspace_size(args);
hytlass::DeviceAllocation<uint8_t> workspace(workspace_size);
result.status = gemm.initialize(args, workspace.get());
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to initialize HYTLASS Grouped GEMM kernel." << std::endl;
return result;
}
// Run the grouped GEMM object
result.status = gemm.run();
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to run HYTLASS Grouped GEMM kernel." << std::endl;
return result;
}
// Wait for completion
result.error = hipDeviceSynchronize();
if (result.error != hipSuccess) {
std::cerr << "Kernel execution error: " << hipGetErrorString(result.error);
return result;
}
//
// Verify correctness
//
result.passed = true;
if (this->options.reference_check) {
result.passed = this->verify();
}
//
// Warm-up run of the grouped GEMM object
//
result.status = gemm.run();
if (result.status != hytlass::Status::kSuccess) {
std::cerr << "Failed to run HYTLASS Grouped GEMM kernel." << std::endl;
return result;
}
//
// Construct events
//
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
        return result;
}
}
// Record an event at the start of a series of GEMM operations
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
//
// Run profiling loop
//
for (int iter = 0; iter < this->options.iterations; ++iter) {
gemm();
}
//
// Stop profiling loop
//
// Record an event when the GEMM operations have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(this->options.iterations);
result.gflops = this->options.gflops(result.runtime_ms / 1000.0);
//
// Cleanup
//
for (auto event : events) {
(void)hipEventDestroy(event);
}
// Optionally profile initialization
if (this->options.profile_initialization) {
// Warm up
gemm.initialize(args, workspace.get());
auto start_time = std::chrono::high_resolution_clock::now();
for (int32_t i = 0; i < this->options.iterations; ++i) {
gemm.initialize(args, workspace.get());
}
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double, std::milli> duration = end_time - start_time;
duration /= double(this->options.iterations);
result.initialization_time_ms = duration.count();
}
int64_t total_tiles = Gemm::group_tile_count(args);
std::cout << " " << total_tiles << " total threadblock tiles." << std::endl;
std::cout << std::endl;
std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl;
if (this->options.profile_initialization) {
std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl;
}
if (this->options.output_file.good()) {
this->options.output_file << this->options.output_tag << ",HYTLASS,grouped-" << sched_mode << ","
<< this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl;
}
std::cout << "\nPassed\n";
return result;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
//
// Define the Grouped and Batched GEMM types
//
using ElementA = float;
using ElementB = float;
using ElementOutput = float;
using ElementAccumulator = float;
using LayoutA = hytlass::layout::ColumnMajor;
using LayoutB = hytlass::layout::RowMajor;
using LayoutC = hytlass::layout::ColumnMajor;
  // Batched GEMM operator (f32 tensor op, 128x128x32 threadblock tile)
using GemmBatched = hytlass::gemm::device::GemmUniversal<
ElementA, LayoutA,
ElementB, LayoutB,
ElementOutput, LayoutC,
ElementAccumulator,
hytlass::arch::OpClassTensorOp,
hytlass::arch::Gfx928,
hytlass::gemm::GemmShape<128, 128, 32>,
hytlass::gemm::GemmShape<64, 64, 32>,
hytlass::gemm::GemmShape<16, 16, 8>,
hytlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / hytlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
1
>;
// Define a grouped GEMM kernel with all template parameters set except
// for scheduling mode. This will be used as the template for all scheduling
// modes executed.
const static int kAlignmentA = 128 / hytlass::sizeof_bits<ElementA>::value;
const static int kAlignmentB = 128 / hytlass::sizeof_bits<ElementB>::value;
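  // Worked example: for float, sizeof_bits = 32, so kAlignmentA = 128 / 32 = 4
  // elements, i.e., each access is 128 bits (16 bytes) wide.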
using GemmKernel = typename hytlass::gemm::kernel::DefaultGemmGrouped<
ElementA,
LayoutA,
hytlass::ComplexTransform::kNone,
kAlignmentA,
ElementB,
LayoutB,
hytlass::ComplexTransform::kNone,
kAlignmentB,
ElementOutput, LayoutC,
ElementAccumulator,
hytlass::arch::OpClassTensorOp,
hytlass::arch::Gfx928,
hytlass::gemm::GemmShape<128, 128, 32>,
hytlass::gemm::GemmShape<64, 64, 32>,
hytlass::gemm::GemmShape<16, 16, 8>,
hytlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / hytlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
// NOTE: Threadblock swizzling is currently not supported by HYTLASS's grouped kernels.
// This parameter is passed in at present to match the APIs of other kernels. The parameter
// is unused within the kernel.
hytlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
4>::GemmKernel;
using GemmGrouped = hytlass::gemm::device::GemmGrouped<GemmKernel>;
//
// Profile it
//
TestbedBatched<GemmBatched> testbed_batched(options);
Result result = testbed_batched.profile();
  if (result.error != hipSuccess) {
return 1;
}
using GroupScheduleMode = hytlass::gemm::kernel::GroupScheduleMode;
for (GroupScheduleMode mode : options.scheduler_modes) {
Result result;
switch (mode) {
case GroupScheduleMode::kDeviceOnly:
{
TestbedGrouped<GemmGrouped, GroupScheduleMode::kDeviceOnly> runner(options);
result = runner.profile();
break;
}
case GroupScheduleMode::kHostPrecompute:
{
TestbedGrouped<GemmGrouped, GroupScheduleMode::kHostPrecompute> runner(options);
result = runner.profile();
break;
}
}
if (result.error != hipSuccess) {
return 1;
}
// Override verbose flag to avoid printing duplicate information for each scheduling mode
options.verbose = false;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
hute_gfx928_streamk_gemm
hute_gfx928_streamk_gemm.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Stream-K GEMM example exercising the device-wide GEMM interface
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hute/tensor.hpp"
#include "hute/atom/mma_atom.hpp"
#include "hytlass/numeric_types.h"
#include "hytlass/matrix_coord.h"
#include "hytlass/gemm/dispatch_policy.hpp"
#include "hytlass/gemm/device/gemm_universal_adapter.h"
#include "hytlass/gemm/kernel/gemm_universal.hpp"
#include "hytlass/gemm/kernel/tile_scheduler.hpp"
#include "hytlass/gemm/collective/collective_builder.hpp"
#include "hytlass/epilogue/collective/default_epilogue.hpp"
#include "hytlass/epilogue/collective/collective_builder.hpp"
#include "hytlass/epilogue/thread/linear_combination.h"
#include "hytlass/util/packed_stride.hpp"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/device/gett.hpp"
#include "helper.h"
using namespace hute;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::gemm::GemmCoord problem_size;
int batch_count;
int split_k_slices;
bool deterministic;
float alpha;
float beta;
bool reference_check;
int iterations;
Options():
help(false),
problem_size({8192, 8192, 2048}),
batch_count(1),
split_k_slices(1),
deterministic(false),
reference_check(false),
iterations(10),
alpha(1),
beta()
{}
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("batch_count", batch_count);
cmd.get_cmd_line_argument("split_k_slices", split_k_slices);
cmd.get_cmd_line_argument("deterministic", deterministic);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "06_hute_streamk example\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n"
<< " --batch_count=<int> Batch number\n"
<< " --split_k_slices=<int> Split-K factor to emulate\n\n"
<< " --deterministic=<int> Reduction Mode.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/06_hute_streamk/gfx928_streamk_gemm --m=1024 --n=512 --k=1024 \\\n"
<< " --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = problem_size.product() * batch_count;
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
using ElementA = hytlass::half_t;
using LayoutA = hytlass::layout::ColumnMajor;
using ElementB = hytlass::half_t;
using LayoutB = hytlass::layout::RowMajor;
using ElementC = hytlass::half_t;
using LayoutC = hytlass::layout::ColumnMajor;
using ElementAccumulator = float;
using ElementCompute = ElementAccumulator;
constexpr int AlignmentA = 128 / hytlass::sizeof_bits<ElementA>::value;
constexpr int AlignmentB = 128 / hytlass::sizeof_bits<ElementB>::value;
constexpr int AlignmentC = 128 / hytlass::sizeof_bits<ElementC>::value;
using TileShape_MNK = Shape<_128, _128, _32>;
using WarpShape_MNK = Shape<_2, _2, _1>;
using InstructionShape_MNK = Shape<_32, _32, _16>;
using ClusterShape_MNK = Shape<_1, _1, _1>;
using StageCountType = hytlass::gemm::collective::StageCount<2>; // select a two-stage pipeline
using KernelScheduleType = hytlass::gemm::KernelStreamKSpecialized;
using EpilogueTileType = hytlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = hytlass::epilogue::collective::EpilogueScheduleAuto;
using TileSchedulerType = hytlass::gemm::StreamKScheduler;
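// Taken together (a high-level reading, not a normative spec): the builder composes a
// two-stage mainloop whose output tiles are distributed by the Stream-K scheduler, so
// a tile's K extent may be split across threadblocks and reduced through workspace.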
using CollectiveMainloop = typename hytlass::gemm::collective::CollectiveBuilder<
hytlass::arch::Gfx928, hytlass::arch::OpClassTensorOp,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape_MNK, WarpShape_MNK, InstructionShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType
>::CollectiveOp;
using CollectiveEpilogue = typename hytlass::epilogue::collective::CollectiveBuilder<
hytlass::arch::Gfx928, hytlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueScheduleType
>::CollectiveOp;
using GemmKernel = hytlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
TileSchedulerType
>;
using Gemm = hytlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using ElementD = typename Gemm::GemmKernel::ElementD;
using StrideD = typename Gemm::GemmKernel::StrideD;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
using LayoutTagA = hytlass::detail::StrideToLayoutTagA_t<StrideA>;
using LayoutTagB = hytlass::detail::StrideToLayoutTagB_t<StrideB>;
using LayoutTagC = hytlass::detail::StrideToLayoutTagA_t<StrideC>;
using LayoutTagD = hytlass::detail::StrideToLayoutTagA_t<StrideD>;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using RasterOrderOptions = typename Gemm::GemmKernel::TileScheduler::RasterOrderOptions;
using ReductionMode = typename Gemm::GemmKernel::TileScheduler::ReductionMode;
int run(Options &options) {
int m = options.problem_size.m();
int n = options.problem_size.n();
int k = options.problem_size.k();
int batch_count = options.batch_count;
int splits = options.split_k_slices;
bool deterministic = options.deterministic;
float alpha = options.alpha;
float beta = options.beta;
StrideA stride_a = hytlass::make_hute_packed_stride(StrideA{}, hute::make_shape(m, k, batch_count));
StrideB stride_b = hytlass::make_hute_packed_stride(StrideB{}, hute::make_shape(n, k, batch_count));
StrideC stride_c = hytlass::make_hute_packed_stride(StrideC{}, hute::make_shape(m, n, batch_count));
StrideD stride_d = hytlass::make_hute_packed_stride(StrideD{}, hute::make_shape(m, n, batch_count));
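  // Packed strides make each batch contiguous; e.g. for column-major A of shape
  // (m, k, batch) the packed stride is (1, m, m*k), so batch l starts at l * m * k.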
auto a_coord = hytlass::make_Coord(m * batch_count, k);
auto b_coord = hytlass::make_Coord(k, n * batch_count);
auto c_coord = hytlass::make_Coord(m * batch_count, n);
typename LayoutTagA::Stride stride_factor_A;
typename LayoutTagB::Stride stride_factor_B;
typename LayoutTagC::Stride stride_factor_C;
typename LayoutTagD::Stride stride_factor_D;
hytlass::HostTensor<ElementA, LayoutTagA> tensor_A;
hytlass::HostTensor<ElementB, LayoutTagB> tensor_B;
hytlass::HostTensor<ElementC, LayoutTagC> tensor_C;
hytlass::HostTensor<ElementC, LayoutTagD> tensor_D;
hytlass::HostTensor<ElementC, LayoutTagD> reference_D;
ProblemShapeType problem_size = ProblemShapeType{m, n, k, batch_count};
tensor_A.resize(a_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
tensor_B.resize(b_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
tensor_C.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
tensor_D.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D));
reference_D.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D));
auto A = hute::make_tensor(tensor_A.host_data(),
hute::make_layout(hute::make_shape(m, k, batch_count), stride_a));
auto B = hute::make_tensor(tensor_B.host_data(),
hute::make_layout(hute::make_shape(n, k, batch_count), stride_b));
auto C = hute::make_tensor(tensor_C.host_data(),
hute::make_layout(hute::make_shape(m, n, batch_count), stride_c));
auto D = hute::make_tensor(reference_D.host_data(),
hute::make_layout(hute::make_shape(m, n, batch_count), stride_d));
hytlass::reference::host::TensorFillRandomUniform(
tensor_A.host_view(),
6118,
ElementA(5),
ElementA(-5),
0); // <- Fill matrix A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_B.host_view(),
6117,
ElementB(5),
ElementB(-5),
0); // <- Fill matrix B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_C.host_view(),
6116,
ElementC(5),
ElementC(-5),
0); // <- Fill matrix C on host with uniform-distribution random data
hytlass::reference::host::TensorFill(
tensor_D.host_view()); // <- fill matrix D on host with zeros
hytlass::reference::host::TensorFill(
reference_D.host_view()); // <- fill matrix D for reference on host with zeros
hytlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
hytlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = hytlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
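// Tile-scheduler arguments are only meaningful when the Stream-K scheduler is
// selected: a split count, a raster-order choice (here left to the heuristic),
// and whether cross-workgroup reduction of partial tiles must happen in a
// fixed, deterministic order.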
typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args;
if constexpr (std::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag, hytlass::gemm::StreamKScheduler>) {
scheduler_args = {
splits, RasterOrderOptions::Heuristic,
deterministic ? ReductionMode::Deterministic : ReductionMode::Nondeterministic
};
}
auto arguments = typename Gemm::Arguments {
hytlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{
tensor_A.device_data(), stride_a,
tensor_B.device_data(), stride_b
},
{
{hytlass::from_real<ElementScalar>(alpha), hytlass::from_real<ElementScalar>(beta)},
tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d
},
hw_info,
scheduler_args
};
size_t workspace_size = Gemm::get_workspace_size(arguments);
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
Gemm gemm_op;
hytlass::Status status = gemm_op.can_implement(arguments);
HYTLASS_CHECK(status);
// Initialize the GEMM with arguments and workspace
status = gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(status);
Result result;
// Launch initialized HYTLASS kernel
status = gemm_op.run();
(void)hipDeviceSynchronize();
HYTLASS_CHECK(status);
// verify
hytlass::reference::device::gett<
ProblemShapeType, ElementA, StrideA, ElementB, StrideB,
ElementAccumulator, ElementC, StrideC, ElementD, StrideD, ElementScalar>(
problem_size, tensor_A.device_data(), stride_a, tensor_B.device_data(),
stride_b, ElementAccumulator(0), tensor_C.device_data(), stride_c,
reference_D.device_data(), stride_d, alpha, beta);
tensor_D.sync_host();
reference_D.sync_host();
result.passed = hytlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
if (result.passed) {
std::cout << "Reference check passed." << std::endl;
}
else {
std::cerr << "Error - reference check failed." << std::endl;
}
// warmup
for (int iter = 0; iter < 10; ++iter) {
status = gemm_op(arguments, workspace.get());
}
(void)hipDeviceSynchronize();
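// Construct events to bracket the profiling loop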
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
}
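// Record an event at the start of the series of GEMMs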
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
for (int iter = 0; iter < options.iterations; ++iter) {
// Launch initialized HYTLASS kernel
status = gemm_op(arguments, workspace.get());
HYTLASS_CHECK(status);
}
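// Record an event when the GEMMs are complete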
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << "GFLOPs: " << result.gflops << std::endl;
return 0;
}
int main(int argc, const char **argv) {
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
printf("%d x %d x %d batch_count=%d splits=%d tensor op Matrix Multiply\n", \
options.problem_size.m(), options.problem_size.n(), options.problem_size.k(),options.batch_count, options.split_k_slices);
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
return run(options);
}
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
hute_gfx928_batch_gemm
hute_gfx928_batch_gemm.cu
)
hytlass_example_add_executable(
hute_gfx928_ptr_array_batched_gemm
hute_gfx928_ptr_array_batched_gemm.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hute/tensor.hpp"
#include "hute/atom/mma_atom.hpp"
#include "hytlass/numeric_types.h"
#include "hytlass/matrix_coord.h"
#include "hytlass/gemm/dispatch_policy.hpp"
#include "hytlass/gemm/device/gemm_universal_adapter.h"
#include "hytlass/gemm/kernel/gemm_universal.hpp"
#include "hytlass/gemm/collective/collective_builder.hpp"
#include "hytlass/epilogue/collective/default_epilogue.hpp"
#include "hytlass/epilogue/collective/collective_builder.hpp"
#include "hytlass/epilogue/thread/linear_combination.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/packed_stride.hpp"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/device/gett.hpp"
#include "helper.h"
using namespace hute;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::gemm::GemmCoord problem_size;
int batch_count;
float alpha;
float beta;
bool reference_check;
int iterations;
Options():
  help(false),
  problem_size({8192, 8192, 2048}),
  batch_count(1),
  alpha(1),
  beta(),
  reference_check(false),
  iterations(10)
{}
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("batch_count", batch_count);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "07_hute_batch_gemm example\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/07_hute_batch_gemm/gfx928_batch_gemm --m=1024 --n=512 --k=1024 \\\n"
<< " --alpha=2 --beta=0.707 --batch_count=2 \n\n";
return out;
}
/// Compute performance in GFLOP/s
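/// For the default 8192 x 8192 x 2048 problem with batch_count = 1 this is
/// about 1.37e11 multiply-adds, i.e. roughly 2.75e11 flops per iteration.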
double gflops(double runtime_s) const {
// Number of real-valued multiply-adds
int64_t fmas = problem_size.product() * batch_count;
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
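// GEMM configuration: fp16 A (column-major) x fp16 B (row-major) -> fp16 C/D
// (column-major) with fp32 accumulation, a 128x128x32 threadblock tile, and a
// two-stage multistage mainloop.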
using ElementA = hytlass::half_t;
using LayoutA = hytlass::layout::ColumnMajor;
constexpr int AlignmentA = 128 / sizeof_bits_v<ElementA>;
using ElementB = hytlass::half_t;
using LayoutB = hytlass::layout::RowMajor;
constexpr int AlignmentB = 128 / sizeof_bits_v<ElementB>;
using ElementC = hytlass::half_t;
using LayoutC = hytlass::layout::ColumnMajor;
constexpr int AlignmentC = 128 / sizeof_bits_v<ElementC>;
using ElementAccumulator = float;
using ElementCompute = ElementAccumulator;
using TileShape_MNK = Shape<_128, _128, _32>;
using WarpShape_MNK = Shape<_2, _2, _1>;
using InstructionShape_MNK = Shape<_32, _32, _16>;
using ClusterShape_MNK = Shape<_1, _1, _1>;
using StageCountType = hytlass::gemm::collective::StageCount<2>;
using KernelScheduleType = hytlass::gemm::KernelMultistage;
using EpilogueTileType = hytlass::epilogue::collective::EpilogueTileAuto;
using EpilogueScheduleType = hytlass::epilogue::collective::EpilogueScheduleAuto;
using CollectiveMainloop = typename hytlass::gemm::collective::CollectiveBuilder<
hytlass::arch::Gfx928, hytlass::arch::OpClassTensorOp,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape_MNK, WarpShape_MNK, InstructionShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType
>::CollectiveOp;
using CollectiveEpilogue = typename hytlass::epilogue::collective::CollectiveBuilder<
hytlass::arch::Gfx928, hytlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueScheduleType
>::CollectiveOp;
using GemmKernel = hytlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
>;
using Gemm = hytlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using ElementD = typename Gemm::GemmKernel::ElementD;
using StrideD = typename Gemm::GemmKernel::StrideD;
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
using LayoutTagA = hytlass::detail::StrideToLayoutTagA_t<StrideA>;
using LayoutTagB = hytlass::detail::StrideToLayoutTagB_t<StrideB>;
using LayoutTagC = hytlass::detail::StrideToLayoutTagA_t<StrideC>;
using LayoutTagD = hytlass::detail::StrideToLayoutTagA_t<StrideD>;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
int run(Options &options) {
int m = options.problem_size.m();
int n = options.problem_size.n();
int k = options.problem_size.k();
int batch_count = options.batch_count;
float alpha = options.alpha;
float beta = options.beta;
StrideA stride_a = hytlass::make_hute_packed_stride(StrideA{}, hute::make_shape(m, k, batch_count));
StrideB stride_b = hytlass::make_hute_packed_stride(StrideB{}, hute::make_shape(n, k, batch_count));
StrideC stride_c = hytlass::make_hute_packed_stride(StrideC{}, hute::make_shape(m, n, batch_count));
StrideD stride_d = hytlass::make_hute_packed_stride(StrideD{}, hute::make_shape(m, n, batch_count));
auto a_coord = hytlass::make_Coord(m * batch_count, k);
auto b_coord = hytlass::make_Coord(k, n * batch_count);
auto c_coord = hytlass::make_Coord(m * batch_count, n);
typename LayoutTagA::Stride stride_factor_A;
typename LayoutTagB::Stride stride_factor_B;
typename LayoutTagC::Stride stride_factor_C;
typename LayoutTagD::Stride stride_factor_D;
hytlass::HostTensor<ElementA, LayoutTagA> tensor_A;
hytlass::HostTensor<ElementB, LayoutTagB> tensor_B;
hytlass::HostTensor<ElementC, LayoutTagC> tensor_C;
hytlass::HostTensor<ElementC, LayoutTagD> tensor_D;
hytlass::HostTensor<ElementC, LayoutTagD> reference_D;
ProblemShapeType problem_size = ProblemShapeType{m, n, k, batch_count};
tensor_A.resize(a_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
tensor_B.resize(b_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
tensor_C.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
tensor_D.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D));
reference_D.resize(c_coord, hytlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D));
auto A = hute::make_tensor(tensor_A.host_data(),
hute::make_layout(hute::make_shape(m, k, batch_count), stride_a));
auto B = hute::make_tensor(tensor_B.host_data(),
hute::make_layout(hute::make_shape(n, k, batch_count), stride_b));
auto C = hute::make_tensor(tensor_C.host_data(),
hute::make_layout(hute::make_shape(m, n, batch_count), stride_c));
auto D = hute::make_tensor(reference_D.host_data(),
hute::make_layout(hute::make_shape(m, n, batch_count), stride_d));
hytlass::reference::host::TensorFillRandomUniform(
tensor_A.host_view(),
6118,
ElementA(5),
ElementA(-5),
0); // <- Fill matrix A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_B.host_view(),
6117,
ElementB(5),
ElementB(-5),
0); // <- Fill matrix B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_C.host_view(),
6116,
ElementC(5),
ElementC(-5),
0); // <- Fill matrix C on host with uniform-distribution random data
hytlass::reference::host::TensorFill(
tensor_D.host_view()); // <- fill matrix D on host with zeros
hytlass::reference::host::TensorFill(
reference_D.host_view()); // <- fill matrix D for reference on host with zeros
hytlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_C.sync_device();
tensor_D.sync_device();
auto arguments = typename Gemm::Arguments {
hytlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{
tensor_A.device_data(), stride_a,
tensor_B.device_data(), stride_b
},
{
{hytlass::from_real<ElementScalar>(alpha), hytlass::from_real<ElementScalar>(beta)},
tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d
}
};
size_t workspace_size = Gemm::get_workspace_size(arguments);
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
Gemm gemm_op;
hytlass::Status status = gemm_op.can_implement(arguments);
HYTLASS_CHECK(status);
// Initialize the GEMM with arguments and workspace
status = gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(status);
// Result structure
Result result;
// Launch initialized HYTLASS kernel
status = gemm_op.run();
(void)hipDeviceSynchronize();
HYTLASS_CHECK(status);
// verify
hytlass::reference::device::gett<
ProblemShapeType, ElementA, StrideA, ElementB, StrideB,
ElementAccumulator, ElementC, StrideC, ElementD, StrideD, ElementScalar>(
problem_size, tensor_A.device_data(), stride_a, tensor_B.device_data(),
stride_b, ElementAccumulator(0), tensor_C.device_data(), stride_c,
reference_D.device_data(), stride_d, alpha, beta);
tensor_D.sync_host();
reference_D.sync_host();
result.passed = hytlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
if (result.passed) {
std::cout << "Reference check passed." << std::endl;
}
else {
std::cerr << "Error - reference check failed." << std::endl;
}
// warmup
for (int iter = 0; iter < 10; ++iter) {
status = gemm_op(arguments, workspace.get());
}
(void)hipDeviceSynchronize();
//
// Construct events
//
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
}
// Record an event at the start of a series of GEMMs
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
for (int iter = 0; iter < options.iterations; ++iter) {
// Launch initialized HYTLASS kernel
status = gemm_op(arguments, workspace.get());
HYTLASS_CHECK(status);
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Compute average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
std::cout << " GFLOPs: " << result.gflops << std::endl;
return 0;
}
int main(int argc, const char **argv) {
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
printf("%d x %d x %d tensor op Matrix Multiply\n", \
options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
return run(options);
}
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hute/tensor.hpp"
#include "hute/atom/mma_atom.hpp"
#include "hytlass/numeric_types.h"
#include "hytlass/matrix_coord.h"
#include "hytlass/gemm/dispatch_policy.hpp"
#include "hytlass/gemm/group_array_problem_shape.hpp"
#include "hytlass/gemm/device/gemm_universal_adapter.h"
#include "hytlass/gemm/kernel/gemm_universal.hpp"
#include "hytlass/gemm/collective/collective_builder.hpp"
#include "hytlass/epilogue/collective/default_epilogue.hpp"
#include "hytlass/epilogue/collective/collective_builder.hpp"
#include "hytlass/epilogue/thread/linear_combination.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/packed_stride.hpp"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/device/tensor_compare.h"
#include "hytlass/util/reference/device/tensor_fill.h"
#include "helper.h"
using namespace hute;
/////////////////////////////////////////////////////////////////////////////////////////////////
using ElementA = hytlass::half_t;
using LayoutA = hytlass::layout::ColumnMajor;
constexpr int AlignmentA = 128 / hytlass::sizeof_bits<ElementA>::value;
using ElementB = hytlass::half_t;
using LayoutB = hytlass::layout::RowMajor;
constexpr int AlignmentB = 128 / hytlass::sizeof_bits<ElementB>::value;
using ElementC = hytlass::half_t;
using LayoutC = hytlass::layout::ColumnMajor;
constexpr int AlignmentC = 128 / hytlass::sizeof_bits<ElementC>::value;
using ElementAccumulator = float;
using ElementCompute = ElementAccumulator;
using TileShape_MNK = Shape<_128, _128, _32>;
using WarpShape_MNK = Shape<_2, _2, _1>;
using InstructionShape_MNK = Shape<_32, _32, _16>;
using ClusterShape_MNK = Shape<_1, _1, _1>;
using StageCountType = hytlass::gemm::collective::StageCount<2>;
// Kernel schedule to launch: specialized for pointer-array (batched) GEMM
using KernelScheduleType = hytlass::gemm::KernelPtrArraySpecialized;
using EpilogueTileType = hytlass::epilogue::collective::EpilogueTileAuto;
// Epilogue schedule to launch: pointer-array, no-smem, warp-specialized
using EpilogueScheduleType = hytlass::epilogue::PtrArrayNoSmemWarpSpecialized;
using TileSchedulerType = hytlass::gemm::PersistentScheduler;
using CollectiveMainloop = typename hytlass::gemm::collective::CollectiveBuilder<
hytlass::arch::Gfx928, hytlass::arch::OpClassTensorOp,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape_MNK, WarpShape_MNK, InstructionShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType
>::CollectiveOp;
using CollectiveEpilogue = typename hytlass::epilogue::collective::CollectiveBuilder<
hytlass::arch::Gfx928, hytlass::arch::OpClassTensorOp,
TileShape_MNK, ClusterShape_MNK,
EpilogueTileType,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
EpilogueScheduleType
>::CollectiveOp;
using GemmKernel = hytlass::gemm::kernel::GemmUniversal<
hytlass::gemm::ArrayProblemShape<Shape<int,int,int,int>>,
CollectiveMainloop,
CollectiveEpilogue,
TileSchedulerType
>;
using Gemm = hytlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using DeviceGemmReference = hytlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
StrideA stride_A;
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
uint64_t seed;
std::vector<int64_t> offset_A;
std::vector<int64_t> offset_B;
std::vector<int64_t> offset_C;
std::vector<int64_t> offset_D;
hytlass::DeviceAllocation<typename Gemm::ElementA> block_A;
hytlass::DeviceAllocation<typename Gemm::ElementB> block_B;
hytlass::DeviceAllocation<typename Gemm::ElementC> block_C;
hytlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;
hytlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_ref_D;
hytlass::DeviceAllocation<const typename Gemm::ElementA *> ptr_A;
hytlass::DeviceAllocation<const typename Gemm::ElementB *> ptr_B;
hytlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
hytlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
hytlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help = false;
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 10;
int m = 1024, n = 512, k = 1024, l = 10;
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("m", m);
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("l", l);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "hute_ptr_array_batched_gemm\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
<< " --n=<int> Sets the N extent of the GEMM\n"
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --l=<int> Sets the batch count for Ptr-Array GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --iterations=<int> Number of profiling iterations to perform\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << "hute_gfx928_ptr_array_batched_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * m * n * k * l;
double gflop = double(flop) / double(1.0e9);
return gflop / runtime_s;
}
};
/// Result structure
struct Result
{
double avg_runtime_ms = 0.0;
double gflops = 0.0;
hytlass::Status status = hytlass::Status::kSuccess;
hipError_t error = hipSuccess;
bool passed = false;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to initialize a block of device data
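/// The fill range narrows with element width, presumably so inputs (and their
/// products) stay exactly representable and the bitwise reference comparison
/// below remains meaningful.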
template <class Element>
bool initialize_block(
hytlass::DeviceAllocation<Element>& block,
uint64_t seed=2023) {
Element scope_max, scope_min;
int bits_input = hytlass::sizeof_bits<Element>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else {
scope_max = 8;
scope_min = -8;
}
hytlass::reference::device::BlockFillRandomUniform(
block.get(), block.size(), seed, scope_max, scope_min, 0);
return true;
}
/// Allocates device-side data
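/// All batches of an operand share one contiguous allocation; offset_X[i]
/// records the element offset at which batch i begins.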
void allocate(const Options &options) {
int64_t total_elements_A = 0;
int64_t total_elements_B = 0;
int64_t total_elements_C = 0;
int64_t total_elements_D = 0;
for (int32_t i = 0; i < options.l; ++i) {
offset_A.push_back(total_elements_A);
offset_B.push_back(total_elements_B);
offset_C.push_back(total_elements_C);
offset_D.push_back(total_elements_D);
int64_t elements_A = options.m * options.k;
int64_t elements_B = options.k * options.n;
int64_t elements_C = options.m * options.n;
int64_t elements_D = options.m * options.n;
total_elements_A += elements_A;
total_elements_B += elements_B;
total_elements_C += elements_C;
total_elements_D += elements_D;
}
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
}
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(const Options &options) {
stride_A = hytlass::make_hute_packed_stride(StrideA{}, hute::make_shape(options.m, options.k, options.l));
stride_B = hytlass::make_hute_packed_stride(StrideB{}, hute::make_shape(options.n, options.k, options.l));
stride_C = hytlass::make_hute_packed_stride(StrideC{}, hute::make_shape(options.m, options.n, options.l));
stride_D = hytlass::make_hute_packed_stride(StrideD{}, hute::make_shape(options.m, options.n, options.l));
//
// Assign pointers
//
std::vector<ElementA *> ptr_A_host(options.l);
std::vector<ElementB *> ptr_B_host(options.l);
std::vector<ElementC *> ptr_C_host(options.l);
std::vector<ElementC *> ptr_D_host(options.l);
for (int32_t i = 0; i < options.l; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
}
ptr_A.reset(options.l);
ptr_A.copy_from_host(ptr_A_host.data());
ptr_B.reset(options.l);
ptr_B.copy_from_host(ptr_B_host.data());
ptr_C.reset(options.l);
ptr_C.copy_from_host(ptr_C_host.data());
ptr_D.reset(options.l);
ptr_D.copy_from_host(ptr_D_host.data());
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
{
hytlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
// to use a GPU other than that with device ID 0.
hw_info.device_id = 0;
hw_info.sm_count = hytlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
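  // kArray mode takes an ArrayProblemShape, hence the extra braces around the
  // single (m, n, k, l) shape shared by every batch in the pointer arrays.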
typename Gemm::Arguments arguments{
hytlass::gemm::GemmUniversalMode::kArray,
{{options.m, options.n, options.k, options.l}},
{ptr_A.get(), stride_A, ptr_B.get(), stride_B},
{{options.alpha, options.beta}, ptr_C.get(), stride_C, ptr_D.get(), stride_D},
hw_info
};
return arguments;
}
bool verify(const Options &options) {
bool passed = true;
for (int32_t i = 0; i < options.l; ++i) {
hytlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({options.m, options.k}));
hytlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({options.k, options.n}));
hytlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({options.m, options.n}));
hytlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({options.m, options.n}));
//
// Compute reference output
//
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
{options.m, options.n, options.k},
ElementAccumulator(options.alpha),
ref_A,
ref_B,
ElementAccumulator(options.beta),
ref_C,
ref_D);
// Wait for kernel to finish
HIP_CHECK(hipDeviceSynchronize());
// Check if output from HYTLASS kernel and reference kernel are equal or not
passed &= hytlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), options.m * options.n);
}
return passed;
}
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
{
allocate(options);
initialize(options);
// Instantiate HYTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check if the problem size is supported or not
HYTLASS_CHECK(gemm.can_implement(arguments));
// Initialize HYTLASS kernel with arguments and workspace pointer
HYTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
HYTLASS_CHECK(gemm.run());
// Check if output from HYTLASS kernel and reference kernel are equal or not
Result result;
result.passed = verify(options);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
if (!result.passed) {
exit(-1);
}
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
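    // Each iteration re-initializes the kernel, so the measured average
    // includes argument/workspace setup as well as the GEMM itself.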
for (int iter = 0; iter < options.iterations; ++iter) {
HYTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
HYTLASS_CHECK(gemm.run());
}
timer.stop();
// Compute average setup and runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
std::cout << " Batches : " << options.l << std::endl;
std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
}
return 0;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate HYTLASS kernels
//
return run<Gemm>(options);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gfx928_tensorop_gemm_bias_relu
gfx928_tensorop_gemm_bias_relu.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
  Example: GEMM with a fused per-row bias add and ReLU epilogue.
*/
#include <iostream>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/epilogue/thread/linear_combination_relu.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hipError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
hytlass::Status status = hytlass::Status::kSuccess,
hipError_t error = hipSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::gemm::GemmCoord problem_size;
float alpha;
float beta;
bool reference_check;
int iterations;
int split_k_slices;
Options():
  help(false),
  problem_size({8192, 8192, 2048}),
  alpha(1),
  beta(),
  reference_check(true),
  iterations(10),
  split_k_slices(1)
{}
bool valid() {
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("split_k_slices", split_k_slices);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "08_tensorop_fused_gemm example\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --split_k_slices=<int> Split-K factor to emulate\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/08_tensorop_fused_gemm/gfx928_tensorop_gemm_bias_relu --m=1024 --n=512 --k=1024 \\\n"
<< " --alpha=2 \n\n";
return out;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// The code section below describes the data types of the input and output
// matrices and of the computation between elements of the input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = hytlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = hytlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = hytlass::half_t; // <- data type of elements in output matrix D
// Note that if the output is column major, the bias has to be per row, i.e., every row has a
// different bias; if the output is row major, the bias has to be per column, i.e., every column
// has a different bias. Some additional notes:
//
// This example only works for ColumnMajor output because
//  1) we only have a row-major epilogue;
//  2) if the output is column major we swap A and B, so we can still use the
//     row-major epilogue;
//  3) the Mx1 bias vector becomes 1xM after the swap/transpose;
//  4) the existing OutputIterator can then load the 1xM bias vector.
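// Concretely, the bias is passed to the kernel below as
// {tensor_c_bias.device_data(), 0}: a leading dimension of zero makes every
// column of C read the same M x 1 vector, so the epilogue sees
// C(i, j) == bias[i] for all j.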
using LayoutInputA = hytlass::layout::ColumnMajor;
using LayoutInputB = hytlass::layout::ColumnMajor;
using LayoutOutput = hytlass::layout::ColumnMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes GFX architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
hytlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32
// This code section describes tile size a warp will compute
using ShapeMMAWarp = hytlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
// This code section describes the size of MMA op
using ShapeMMAOp = hytlass::gemm::GemmShape<16, 16, 16>;  // <- MMA Op tile M = 16, N = 16, K = 16
constexpr int kAlignmentA = 128 / hytlass::sizeof_bits<ElementInputA>::value;
constexpr int kAlignmentB = 128 / hytlass::sizeof_bits<ElementInputB>::value;
constexpr int kAlignmentC = 128 / hytlass::sizeof_bits<ElementOutput>::value;
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- identity threadblock swizzle (default raster order)
// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
//
// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + c_ij )
//
using EpilogueOp = hytlass::epilogue::thread::LinearCombinationRelu<
ElementOutput, // <- data type of output matrix
kAlignmentC, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This becomes
// the vector width of math instructions in
// epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue, // <- data type for alpha in linear combination function
hytlass::epilogue::thread::ScaleType::NoBetaScaling>; // <- alpha x C + bias
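// With ScaleType::NoBetaScaling the source operand is added without scaling,
// i.e. D = relu(alpha * (A @ B) + bias); beta never enters the computation,
// which is why this example only exposes --alpha on the command line.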
// Number of pipelines you want to use
constexpr int NumStages = 1;
using Gemm = hytlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
kAlignmentA,
kAlignmentB>;
int run(Options &options) {
hytlass::gemm::GemmCoord problem_size = options.problem_size;
// Initialize tensors using HYTLASS helper functions
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias(
{problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// HYTLASS kernel
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// reference kernel
// Fill input and output matrices on host using HYTLASS helper functions
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(4),
ElementInputA(-4),
0); // <- Fill matrix A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(4),
ElementInputB(-4),
0); // <- Fill matrix B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_c_bias.host_view(),
1,
ElementOutput(4),
ElementOutput(-4),
0); // <- Fill matrix C on host with uniform-distribution random data
hytlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
hytlass::reference::host::TensorFill(
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c_bias.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
// Initialize alpha for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
// Split the K dimension into partitions as specified by --split_k_slices
int split_k_slices = options.split_k_slices;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated HYTLASS kernel
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
tensor_a.device_ref(), // <- reference to matrix A on device
tensor_b.device_ref(), // <- reference to matrix B on device
{tensor_c_bias.device_data(), 0}, // <- the C matrix is treated as the bias vector. We can enable the GEMM
// to project away the N dimension by setting the stride to zero.
tensor_d.device_ref(), // <- reference to matrix D on device
{alpha}, // <- alpha
split_k_slices}; // <- k-dimension split factor
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate HYTLASS kernel depending on templates
Gemm gemm_op;
// Check the problem size is supported or not
hytlass::Status status = gemm_op.can_implement(arguments);
HYTLASS_CHECK(status);
// Initialize HYTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(status);
// Launch initialized HYTLASS kernel
status = gemm_op();
HYTLASS_CHECK(status);
Result result;
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
}
// Record an event at the start of a series of GEMMs
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
for (int iter = 0; iter < options.iterations; ++iter) {
// Launch initialized HYTLASS kernel
status = gemm_op();
HYTLASS_CHECK(status);
}
//
// Stop profiling loop
//
// Record an event when the GEMMs are complete
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return -1;
}
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
//
// Create instantiation for device reference gemm kernel
//
if (options.reference_check) {
hytlass::reference::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementComputeEpilogue>
gemm_device_reference;
// Launch device reference to compute strictly the product A * B
gemm_device_reference(
problem_size,
alpha,
tensor_a.device_ref(),
tensor_b.device_ref(),
0,
tensor_ref_d.device_ref());
// Wait for kernels to finish
(void)hipDeviceSynchronize();
// Copy output data from HYTLASS and reference kernel to host for comparison
tensor_d.sync_host();
tensor_ref_d.sync_host();
// Compute bias + relu in host code
for (int i = 0; i < problem_size.m(); ++i) {
for (int j = 0; j < problem_size.n(); ++j) {
tensor_ref_d.at({i, j}) = std::max(
ElementOutput(0),
ElementOutput(tensor_ref_d.at({i, j}) + tensor_c_bias.at({i, 0}))
);
}
}
ElementOutput eps(0.05);
const ElementOutput non_zero_floor(1e-6f);
result.passed = hytlass::reference::host::TensorRelativelyEquals(tensor_ref_d.host_view(),
tensor_d.host_view(), eps, non_zero_floor);
}
if (result.passed) {
  std::cout << "Reference check passed." << std::endl;
}
else {
  std::cerr << "Error - reference check failed." << std::endl;
}
return (result.passed ? 0 : -1);
}
int main(int argc, const char **argv) {
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
printf("%d x %d x %d tensor op Matrix Multiply\n", \
options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
return run(options);
}
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gfx928_tensorop_conv2dfprop
gfx928_tensorop_conv2dfprop.cu
)
hytlass_example_add_executable(
gfx928_tensorop_conv2ddgrad
gfx928_tensorop_conv2ddgrad.cu
)
hytlass_example_add_executable(
gfx928_tensorop_conv2dwgrad
gfx928_tensorop_conv2dwgrad.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
This example shows how to compute a 2d transposed convolution, also known as deconvolution, using
HYTLASS conv2d Dgrad kernels. Although the two operations are computationally equivalent, some care
is needed to correctly set up the problem size for HYTLASS.
In deep learning, transposed convolution is sometimes used for upscaling feature maps. This example
demonstrates the 2x upscaling case using the strided Dgrad kernel.
*/
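// A minimal sketch (hypothetical helper, not used by the example below) of the
// transposed-conv output extent that Options::output_size() computes later in
// this file.
constexpr int transposed_conv2d_output_extent(
    int in, int stride, int pad, int r, int dilation, int out_pad) {
  return (in - 1) * stride - 2 * pad + ((r - 1) * dilation + 1) + out_pad;
}
// For the default 2x upscaling case (in = 32, stride = 2, pad = 1, r = 3,
// dilation = 1, out_pad = 1): (32 - 1) * 2 - 2 + 3 + 1 = 64.
static_assert(transposed_conv2d_output_extent(32, 2, 1, 3, 1, 1) == 64,
              "a 32x32 input upscales to a 64x64 output");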
#include <iostream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/tensor_ref.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/conv/kernel/default_conv2d_dgrad.h"
#include "hytlass/conv/device/implicit_gemm_convolution.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/device/convolution.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input and output tensors and of the
// computation between elements
using hytlass::layout::TensorNHWC;
using hytlass::TensorRef;
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = hytlass::half_t; // Data type of elements in input tensor
using ElementInputB = hytlass::half_t; // Data type of elements in input tensor
using ElementOutput = float; // Data type of elements in output tensor
using ElementC = ElementOutput;
using ElementCompute = ElementComputeEpilogue;
using LayoutInputA = TensorNHWC;
using LayoutInputB = TensorNHWC;
using LayoutOutput = TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes GFX architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>;
// Number of pipelines you want to use
constexpr int NumStages = 2;
// This code section describes whether the selected iterator algorithm is Analytic or Optimized
static hytlass::conv::IteratorAlgorithm const IteratorAlgorithm = hytlass::conv::IteratorAlgorithm::kOptimized;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = hytlass::epilogue::thread::LinearCombination<
ElementCompute, // Data type of output matrix.
128 / hytlass::sizeof_bits<ElementCompute>::value, // The number of elements per vectorized
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
using Conv2dDgradKernel = typename hytlass::conv::kernel::DefaultConv2dDgrad<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementAccumulator, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
hytlass::conv::StrideSupport::kStrided // Use the strided Dgrad specialization
>::Kernel;
using ImplicitGemm = hytlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::Tensor4DCoord input_size;
hytlass::Tensor4DCoord filter_size;
hytlass::Tensor4DCoord padding;
hytlass::MatrixCoord conv_stride;
hytlass::MatrixCoord dilation;
bool reference_check;
bool measure_performance;
int iterations;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
std::string tag;
Options():
help(false),
input_size(1, 32, 32, 32),
filter_size(32, 3, 3, 16),
padding(1, 1, 1, 1),
conv_stride(2, 2),
dilation(1, 1),
reference_check(true),
measure_performance(true),
iterations(20),
alpha(1),
beta(0) {}
// Verify the problem size is compatible with the HYTLASS Convolution implementation.
bool valid() {
//
// HYTLASS attempts to load 128b vectors of hytlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
if ((input_size.c() % kAlignment) ||
(filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) ||
(padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("skip-ref-check")) {
reference_check = false;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
// Filter layout is CRSK
cmd.get_cmd_line_argument("k", filter_size.c());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.n() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
if (filter_size.h() == 3 && filter_size.w() == 3) {
padding = {1, 1, 1, 1};
}
else {
filter_size.h() = 1;
filter_size.w() = 1;
padding = {0, 0, 0, 0};
}
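// Only 3x3 filters (pad 1) and 1x1 filters (pad 0) are exercised by this
// example; any other --r/--s extents are coerced to 1x1 above.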
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "34_transposed_conv2d example\n\n"
<< " This example shows how to compute 2d transposed convolution, also known as\n"
<< " deconvolution, using HYTLASS conv2d Dgrad kernels. Although two operations are\n"
<< " computationaly equivalent, some care is needed to correctly set up a problem size.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --skip-ref-check If set (true), skip reference check on the host\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/34_transposed_conv2d/34_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
hytlass::Tensor4DCoord output_size() const {
// Here, out_pad corresponds to "output_padding" of conv2d_transpose op in deep learning frameworks.
// See for example https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
int out_pad_h = conv_stride.row() > 1 ? 1 : 0;
int out_pad_w = conv_stride.column() > 1 ? 1 : 0;
int out_h = (input_size.h() - 1) * conv_stride.row() - 2 * padding.n() + (((filter_size.h() - 1) * dilation.row() + 1)) + out_pad_h;
int out_w = (input_size.w() - 1) * conv_stride.column() - 2 * padding.w() + (((filter_size.w() - 1) * dilation.column() + 1)) + out_pad_w;
return hytlass::Tensor4DCoord(input_size.n(), out_h, out_w, filter_size.c());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NHWC * KRS
// Note that the input with the layout NHWC corresponds to the output from the perspective of dgrad,
// and that the filter layout is CRSK.
int64_t fmas = input_size.product() * int64_t(filter_size.h() * filter_size.w() * filter_size.n());
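// e.g., with the default sizes (input 1x32x32x32, filter R = S = 3,
// filter_size.n() = 32): fmas = 32768 * (3 * 3 * 32) = 9,437,184,
// i.e. roughly 18.9 MFLOP per convolution.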
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hytlass::Status reference_check;
hipError_t error;
Result():
runtime_ms(0),
gflops(0),
status(hytlass::Status::kSuccess),
reference_check(hytlass::Status::kInvalid),
error(hipSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.c() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< options.conv_stride.row() << ","
<< options.conv_stride.column() << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
// This is the same as Conv2dDgrad in tools/util/include/hytlass/util/reference/host/convolution.h,
// only variable names have been adapted for transposed conv2d.
void Conv2dTransposeReference(
hytlass::conv::Conv2dProblemSize problem_size,
TensorRef<ElementInputA, LayoutInputA> tensor_a,
TensorRef<ElementInputB, LayoutInputB> tensor_b,
TensorRef<ElementC, LayoutOutput> tensor_c,
TensorRef<ElementC, LayoutOutput> tensor_d,
ElementCompute alpha,
ElementCompute beta) {
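// Note the crosswise mapping below: the local input extents (H, W) are read
// from the problem's output fields (P, Q) and the local output extents from
// (H, W); the channel counts K and C swap likewise, because the transposed-conv
// input corresponds to the dgrad output gradient.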
int H = problem_size.P;
int W = problem_size.Q;
int P = problem_size.H;
int Q = problem_size.W;
int K = problem_size.C;
int C = problem_size.K;
for (int n = 0; n < problem_size.N; ++n) {
for (int p = 0; p < P; ++p) {
for (int q = 0; q < Q; ++q) {
for (int k = 0; k < K; ++k) {
ElementAccumulator acc = ElementAccumulator();
for (int r = 0; r < problem_size.R; ++r) {
for (int s = 0; s < problem_size.S; ++s) {
for (int c = 0; c < C; ++c) {
int filter_r = r;
int filter_s = s;
int h = p + problem_size.pad_h - filter_r * problem_size.dilation_h;
int w = q + problem_size.pad_w - filter_s * problem_size.dilation_w;
if (h >= 0 && (h % problem_size.stride_h) == 0 &&
w >= 0 && (w % problem_size.stride_w) == 0) {
h = h / problem_size.stride_h;
w = w / problem_size.stride_w;
if (h < H && w < W) {
ElementInputA a = tensor_a.at(hytlass::make_Coord(n, h, w, c));
ElementInputB b = tensor_b.at(hytlass::make_Coord(c, r, s, k));
acc += ElementAccumulator(a) * ElementAccumulator(b);
}
}
} // for (C)
} // for (S)
} // for (R)
// Apply Epilogue, compute ElementCompute, convert and store ElementC
ElementC c_ref = ElementC();
if (beta != ElementCompute()) {
c_ref = tensor_c.at(hytlass::make_Coord(n, p, q, k));
}
tensor_d.at(hytlass::make_Coord(n, p, q, k)) = alpha * ElementCompute(acc) + beta * ElementCompute(c_ref);
} // for (K)
} // for (W)
} // for (H)
} // for (N)
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile_convolution(Options const &options) {
std::cout << "Output shape: " << options.output_size() << std::endl;
Result result;
//
// Allocate host-device tensors using the HYTLASS Utilities.
//
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(7),
ElementInputA(-8),
0);
// Fill tensor B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(7),
ElementInputB(-8),
0);
// Fill tensor C and D on host with zeros
hytlass::reference::host::TensorFill(tensor_c.host_view());
hytlass::reference::host::TensorFill(tensor_d.host_view());
// Fill tensor D for reference on host with zeros
hytlass::reference::host::TensorFill(tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
//
// Define arguments for HYTLASS Convolution
//
hytlass::conv::Mode mode = hytlass::conv::Mode::kCrossCorrelation;
// Construct Conv2dProblemSize with user defined output size
// The input in transposed conv2d corresponds to the output in the equivalent dgrad.
// Similarly for the output.
// Although the filter layout is CRSK from the perspective of conv2d transpose,
// the filter size does not need to change for setting up the problem size.
// There is no need to transpose the filter tensor either.
hytlass::conv::Conv2dProblemSize problem_size(
options.output_size(),
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.input_size,
mode
);
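// e.g., with the default sizes this builds the dgrad-equivalent problem whose
// conv "input" is the upscaled 1x64x64x16 output tensor and whose conv
// "output" is the original 1x32x32x32 transposed-conv input tensor.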
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_d.device_ref(),
{options.alpha, options.beta}
};
//
// Initialize HYTLASS Convolution
//
ImplicitGemm implicit_gemm;
size_t workspace_size = implicit_gemm.get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm.can_implement(arguments);
HYTLASS_CHECK(result.status);
result.status = implicit_gemm.initialize(arguments, workspace.get());
HYTLASS_CHECK(result.status);
result.status = implicit_gemm();
HYTLASS_CHECK(result.status);
// Verify against the host reference implementation (Conv2dTransposeReference above).
if (options.reference_check) {
tensor_d.sync_host();
std::cout << "Verification on host...\n";
Conv2dTransposeReference(problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_d.host_ref(),
options.alpha, options.beta);
bool passed = hytlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view());
if (!passed) {
result.reference_check = hytlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
}
else {
result.reference_check = hytlass::Status::kSuccess;
std::cout << "Passed.\n";
}
}
if (options.measure_performance) {
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm();
HYTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to run convolution kernels on tensor cores using functions and data
structures provided by HYTLASS.
Writing a single high-performance convolution kernel is hard but doable, whereas writing
high-performance kernels at scale that work for multiple problem sizes with good abstractions is
really hard. HYTLASS solves this problem by providing simplified abstractions to compose
multiple sections of an implicit GEMM kernel. When used properly, the kernels can easily hit
peak performance of the GPU.
HYTLASS divides a kernel into hierarchical, composable sections. That is, at the thread, warp,
and threadblock level, each computes on its own tile size, with higher-level tile sizes composed
from lower-level ones. Multiple thread tiles (the tile size each thread computes) form a warp
tile (the tile size each warp computes), and multiple warp tiles form a threadblock tile (the
tile size computed by a threadblock).
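As a concrete instance of this hierarchy, using the tile shapes defined later in this example: a
128x128x32 threadblock tile is covered by a 2x2 arrangement of 64x64x32 warp tiles, and each warp
tile is computed as a 4x4 grid of 16x16 accumulator tiles fed by 16x16x16 tensor core
instructions (two instructions along K per warp-tile iteration).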
In this example, we split variable initialization into two parts:
1. Setting up data properties: describes how tensors are laid out in memory and how the kernel
can view them (logical-to-physical mapping).
2. Setting up computation properties: describes how the above tensors will be used to compute
the output of the convolution.
First, we set up the data types of the input tensor A, the weights tensor B, and the output tensor
C, along with alpha and beta, as the equation for convolution is C = alpha * Conv(A, B) + beta * C.
In HYTLASS, the kernels first compute Conv(A, B) and leave the rest of the computation to the end
of the kernel, as alpha * X + beta * C is a simple element-wise operation on X (= Conv(A, B)) and
C. We call this the epilogue of the kernel. Hence, we set up the data types for alpha and beta to
be ElementComputeEpilogue = float. This example uses tensor core MMA instructions on
half-precision inputs, so the elements of input tensors A and B use hytlass::half_t. We convey
this to the HYTLASS kernel by initializing the template variables ElementAccumulator (float),
ElementComputeEpilogue (float), ElementInputA (hytlass::half_t), ElementInputB (hytlass::half_t),
and ElementOutput (float). Communicating just the data type is not enough. As the data is laid
out linearly in memory, we also have to convey the layout of the tensors. We do that by
initializing the template variables LayoutInputA, LayoutInputB, and LayoutOutput to the hytlass
TensorNHWC layout. Next, we set up the rule to compute alpha * X + beta * C, i.e. the epilogue of
the kernel. We initialize the template variable EpilogueOp, which takes the data type of the
output ElementOutput (float), the number of elements per vectorized memory access
(128b / 32b = 4), the data type of the accumulator (float), and the data type of the
linear-combination computation (alpha * X + beta * C).
Now that we have set up the properties of the data, we have to set up the properties of the
computation.
Second, we create template variables for the tile sizes computed by a threadblock, a warp, and an
MMA op: 128x128x32, 64x64x32, and 16x16x16 (MxNxK) respectively. When these are passed to
instantiate the HYTLASS Implicit GEMM kernel, it internally deduces the number of threads needed
per threadblock, the amount of shared memory, how to store data in a bank-conflict-free manner,
and a ton of other variables required to compose, initialize, and launch a high-performance
Implicit GEMM kernel. This is the beauty of HYTLASS: it relieves the developer from understanding
and coding complicated hardware optimizations, which can easily go wrong.
HYTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? An MMA
pipeline constitutes the whole process of loading input data from global memory to shared memory,
loading data from shared memory to registers, doing matrix multiplication, and storing the result
to global memory. The flow below shows a typical MMA pipeline:
tensor in global memory -> registers -> tile in shared memory -> registers -> mma -> registers ->
output to global memory
The problem with a single pipeline is that each stage is synchronous, meaning each stage has to
wait until the previous one has finished executing. Some stages in the pipeline do not have a
fixed latency, for example the loads from global memory and shared memory. Therefore, we can add
one more pipeline with a phase shift in the MMA kernel to hide the latency of the global and
shared memory loads. Finally, the pipelines in such a kernel look like
(1) tensor in global memory -> (2) registers -> (3) tile in shared memory -> (4) registers ->
(5) mma -> (6) registers -> (7) output to global memory
(1) <null> -> (2) <null> -> (3) tensor in global memory -> (4) registers -> (5) tile in shared
memory -> (6) registers -> (7) mma -> (8) registers -> (9) output to global memory
This way, you can hide the latency of the second global memory load by doing computation on
already-loaded input data.
A few more template variables are initialized, such as the swizzle that decides which threadblock
tile of the output matrix is computed by which launched threadblock, and the GFX architecture of
the GPU you want to run on.
These are all put together to create a template variable which describes the HYTLASS Implicit
GEMM kernel, using the hytlass::conv::device::ImplicitGemm template.
The next step is to initialize physical data, then instantiate, initialize, and run the HYTLASS
kernel. We use HYTLASS utilities to initialize, fill, and compare tensors, as they are simple and
don't get in the way of learning HYTLASS.
Once all the tensors are initialized and filled with data, we create the arguments tuple used to
launch the HYTLASS kernel. It takes the problem size (by default N = 1, H = 32, W = 32, C = 32),
the filter size (K = 32, R = 3, S = 3, C = 32), padding, strides, dilation, tensors, alpha, beta
and, importantly, the split k-dimension factor. Along with that, we query HYTLASS whether any
scratch-space memory is required by the kernel we instantiated. If yes, we create it and pass it
along with the other arguments to initialize the HYTLASS kernel; then the kernel is launched.
In this example, we finally run a host reference convolution (from the HYTLASS utilities) to
check that the output of the HYTLASS kernel matches the reference implementation.
*/
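// A minimal sketch (illustrative only, not the HYTLASS functor) of the
// element-wise linear combination the epilogue applies to every output element:
constexpr float epilogue_linear_combination(
    float alpha, float accumulator, float beta, float c) {
  return alpha * accumulator + beta * c;
}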
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/conv/kernel/default_conv2d_fprop.h"
#include "hytlass/conv/device/implicit_gemm_convolution.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/convolution.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input and output tensors and of the
// computation between elements
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = hytlass::half_t; // Data type of elements in input tensor
using ElementInputB = hytlass::half_t; // Data type of elements in input tensor
using ElementOutput = float; // Data type of elements in output tensor
using LayoutInputA = hytlass::layout::TensorNHWC;
using LayoutInputB = hytlass::layout::TensorNHWC;
using LayoutOutput = hytlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes GFX architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// 1 -> singlestage
// 2 -> pipelined
constexpr int NumStages = 1;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = hytlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / hytlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
using Conv2dFpropKernel = typename hytlass::conv::kernel::DefaultConv2dFprop<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
hytlass::conv::IteratorAlgorithm::kAnalytic,
hytlass::conv::StrideSupport::kStrided
>::Kernel;
using ImplicitGemm = hytlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::Tensor4DCoord input_size;
hytlass::Tensor4DCoord filter_size;
hytlass::Tensor4DCoord padding;
hytlass::MatrixCoord conv_stride;
hytlass::MatrixCoord dilation;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool benchmark;
std::string tag;
Options():
help(false),
input_size(1, 32, 32, 32),
filter_size(32, 3, 3, 32),
padding(1, 1, 1, 1),
conv_stride(1, 1),
dilation(1, 1),
reference_check(true),
measure_performance(true),
iterations(20),
save_workspace(false),
alpha(1),
beta(0),
benchmark(false) { }
// Verify the problem size is compatible with the HYTLASS Convolution implementation.
bool valid() {
//
// HYTLASS attempts to load 128b vectors. Consequently, all pointers, strides,
// and tensor extents must be divisible by 128 / sizeof_bits<Element> elements
// (8 elements for the hytlass::half_t inputs used here).
//
int const kAlignment = 128 / hytlass::sizeof_bits<ElementInputA>::value;
if ((input_size.c() % kAlignment) ||
(filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) ||
(padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
hytlass::Tensor4DCoord input_size,
hytlass::Tensor4DCoord filter_size) {
this->input_size = input_size;
this->filter_size = filter_size;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
if (cmd.check_cmd_line_flag("benchmark")) {
benchmark = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.c() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
int p = 1;
int d = 1;
// padding
cmd.get_cmd_line_argument("p", p);
// stride
cmd.get_cmd_line_argument("d", d);
padding = {p, p, p, p};
conv_stride = {d, d};
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "gfx928_tensorop_conv2dfprop example\n\n"
<< " This example uses Turing's Tensor Core operators on int4 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --p=<int> Padding\n\n"
<< " --d=<int> Strided\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./gfx928_tensorop_conv2dfprop --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
<< "$ ./gfx928_tensorop_conv2dfprop --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
hytlass::Tensor4DCoord output_size() const {
return hytlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hytlass::Status reference_check;
hipError_t error;
Result():
runtime_ms(0),
gflops(0),
status(hytlass::Status::kSuccess),
reference_check(hytlass::Status::kInvalid),
error(hipSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the HYTLASS Utilities.
//
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.output_size());
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(5),
ElementInputA(-5),
5);
// Fill tensor B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
2,
ElementInputB(5),
ElementInputB(-5),
5);
// Fill tensor C on host with zeros
hytlass::reference::host::TensorFill(
tensor_c.host_view());
// Fill tensor C for reference on host with zeros
hytlass::reference::host::TensorFill(
tensor_ref_c.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_ref_c.sync_device();
//
// Define arguments for HYTLASS Convolution
//
// mode (kCrossCorrelation or kConvolution)
hytlass::conv::Mode mode = hytlass::conv::Mode::kCrossCorrelation;
// Do not split the K dimension (split_k_slices = 1)
int split_k_slices = 1;
// Construct Conv2dProblemSize with user defined output size
hytlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices);
// Construct ImplicitGemm::Argument structure with conv2d
// problem size, data pointers, and epilogue values
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_c.device_ref(),
{options.alpha, options.beta},
};
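// Note: tensor_c.device_ref() is passed as both the C source operand and the
// destination D, so the epilogue writes its result back into tensor_c.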
//
// Initialize HYTLASS Convolution
//
ImplicitGemm implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_op.can_implement(arguments);
HYTLASS_CHECK(result.status);
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(result.status);
//
// Launch initialized HYTLASS kernel
//
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on host...\n";
// Compute with reference implementation
hytlass::reference::host::Conv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
ElementOutput,
hytlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_c.host_ref(),
options.alpha,
options.beta
);
// Check whether the outputs of the HYTLASS kernel and the reference kernel match
tensor_c.sync_host();
const ElementOutput non_zero_floor(1e-6f);
ElementOutput eps(0.05);
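// Half-precision inputs accumulated in float can differ slightly from the host
// reference (e.g., due to a different reduction order), so the comparison uses
// a relative tolerance rather than exact equality.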
bool passed = hytlass::reference::host::TensorRelativelyEquals(tensor_c.host_view(), tensor_ref_c.host_view(), eps, non_zero_floor);
if (!passed) {
result.reference_check = hytlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
}
else {
result.reference_check = hytlass::Status::kSuccess;
std::cout << "Passed.\n";
}
}
else {
result.reference_check = hytlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "09_tensor_conv_workspace_conv2dfprop_"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
<< ".dat";
std::ofstream output_workspace(ss.str());
output_workspace
<< "Input = \n" << tensor_a.host_view() << "\n\n"
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.benchmark) {
// Benchmark several layers
int batch_sizes[] = {1, 32, 64, 128, 256, 512};
struct Benchmark {
int h, w, c, k, r, s;
} layers[] = {
{56, 56, 64, 256, 1, 1},
{56, 56, 64, 64, 1, 1},
{56, 56, 64, 64, 3, 3},
{56, 56, 256, 64, 1, 1},
{56, 56, 256, 512, 1, 1},
{56, 56, 256, 128, 1, 1},
{28, 28, 128, 128, 3, 3},
{28, 28, 128, 512, 1, 1},
{28, 28, 512, 128, 1, 1},
{28, 28, 512, 1024, 1, 1},
{28, 28, 512, 256, 1, 1},
{14, 14, 256, 256, 3, 3},
{14, 14, 256, 1024, 1, 1},
{14, 14, 1024, 256, 1, 1},
{14, 14, 1024, 2048, 1, 1},
{14, 14, 1024, 512, 1, 1},
{7, 7, 512, 512, 3, 3},
};
Result::print_header(std::cout, options) << std::endl;
int idx = 1;
for (auto const &layer : layers) {
for (auto N : batch_sizes) {
options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});
Result result = profile_convolution(options);
result.print(std::cout, idx, options) << std::endl;
}
++idx;
}
}
else {
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to compute conv2d gradient with respect to weight (wgrad). Although conceptually similar to
forward convolution, wgrad accumulates contributions from input activations and output gradients to form updates to
the filter weights, and some care is needed to correctly set up a problem size for HYTLASS.
In deep learning, weight gradient computation is a key step during training, as it provides the parameter updates
required for learning. This example demonstrates a typical case of computing convolution weight gradients using
HYTLASS Wgrad kernels.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/conv/kernel/default_conv2d_wgrad.h"
#include "hytlass/conv/device/implicit_gemm_convolution.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/host/convolution.h"
#include "hytlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes the data types of the input and output tensors and of the
// computation between elements
using ElementAccumulator = int32_t; // Data type of accumulator
using ElementComputeEpilogue = int32_t; // Data type of epilogue computation (alpha, beta)
using ElementInputA = int8_t; // Data type of elements in input tensor
using ElementInputB = int8_t; // Data type of elements in input tensor
using ElementOutput = int32_t; // Data type of elements in output tensor
using LayoutInputA = hytlass::layout::TensorNHWC;
using LayoutInputB = hytlass::layout::TensorNHWC;
using LayoutOutput = hytlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes GFX architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = hytlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 32>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Number of pipelines you want to use
constexpr int NumStages = 2;
// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = hytlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / hytlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
using DefaultConv2dWgrad = typename hytlass::conv::kernel::DefaultConv2dWgrad<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
hytlass::conv::IteratorAlgorithm::kOptimized,
hytlass::conv::StrideSupport::kUnity
>::Kernel;
using ImplicitGemm = hytlass::conv::device::ImplicitGemmConvolution<DefaultConv2dWgrad>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::Tensor4DCoord input_size;
hytlass::Tensor4DCoord filter_size;
hytlass::Tensor4DCoord padding;
hytlass::MatrixCoord conv_stride;
hytlass::MatrixCoord dilation;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool benchmark;
std::string tag;
Options():
help(false),
input_size(1, 32, 32, 32),
filter_size(32, 3, 3, 32),
padding(1, 1, 1, 1),
conv_stride(1, 1),
dilation(1, 1),
reference_check(true),
measure_performance(true),
iterations(20),
save_workspace(false),
alpha(1),
beta(0),
benchmark(false) { }
// Verify the problem size is compatible with the HYTLASS Convolution implementation.
bool valid() {
//
// HYTLASS attempts to load 128b vectors. Consequently, all pointers, strides,
// and tensor extents must be divisible by 128 / sizeof_bits<Element> elements
// (16 elements for the int8_t inputs used here).
//
int const kAlignment = 128 / hytlass::sizeof_bits<ElementInputA>::value;
if ((input_size.c() % kAlignment) ||
(filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) ||
(padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
hytlass::Tensor4DCoord input_size,
hytlass::Tensor4DCoord filter_size) {
this->input_size = input_size;
this->filter_size = filter_size;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
if (cmd.check_cmd_line_flag("benchmark")) {
benchmark = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.c() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
if (filter_size.h() == 3 && filter_size.w() == 3) {
padding = {1, 1, 1, 1};
}
else {
filter_size.h() = 1;
filter_size.w() = 1;
padding = {0, 0, 0, 0};
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "09_hytlass_tensorop_conv2d example\n\n"
<< " This example uses Turing's Tensor Core operators on int4 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./gfx928_tensorop_conv2dwgrad --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
<< "$ ./gfx928_tensorop_conv2dwgrad --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
hytlass::Tensor4DCoord output_size() const {
return hytlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hytlass::Status reference_check;
hipError_t error;
Result():
runtime_ms(0),
gflops(0),
status(hytlass::Status::kSuccess),
reference_check(hytlass::Status::kInvalid),
error(hipSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the HYTLASS Utilities.
//
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.output_size());
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.input_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_c(options.filter_size);
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(5),
ElementInputA(-5),
5);
// Fill tensor B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
2,
ElementInputB(5),
ElementInputB(-5),
5);
// Fill tensor C on host with zeros
hytlass::reference::host::TensorFill(
tensor_c.host_view());
// Fill tensor C for reference on host with zeros
hytlass::reference::host::TensorFill(
tensor_ref_c.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_ref_c.sync_device();
//
// Define arguments for HYTLASS Convolution
//
// mode (kCrossCorrelation or kConvolution)
hytlass::conv::Mode mode = hytlass::conv::Mode::kCrossCorrelation;
// Use a single partition of the K dimension (no split-k)
int split_k_slices = 1;
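// With one slice, each threadblock performs the full GEMM-K reduction serially,
// so no inter-block reduction step is required.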
// Construct Conv2dProblemSize with user defined output size
hytlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices);
// Construct ImplicitGemm::Argument structure with conv2d
// problem size, data pointers, and epilogue values
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(), // source tensor C
tensor_c.device_ref(), // destination tensor D (aliases C; the result is written in place)
{options.alpha, options.beta},
};
//
// Initialize HYTLASS Convolution
//
ImplicitGemm implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm_op.can_implement(arguments);
HYTLASS_CHECK(result.status);
result.status = implicit_gemm_op.initialize(arguments, workspace.get());
HYTLASS_CHECK(result.status);
//
// Launch initialized HYTLASS kernel
//
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on host...\n";
// Compute with reference implementation
hytlass::reference::host::Conv2dWgrad<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
ElementOutput,
hytlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
>(
problem_size,
tensor_a.host_ref(),
tensor_b.host_ref(),
tensor_c.host_ref(),
tensor_ref_c.host_ref(),
options.alpha,
options.beta
);
// Check if output from HYTLASS kernel and reference kernel are equal or not
tensor_c.sync_host();
const ElementOutput non_zero_floor(0);
ElementOutput eps(0);
bool passed = hytlass::reference::host::TensorRelativelyEquals(tensor_c.host_view(), tensor_ref_c.host_view(), eps, non_zero_floor);
if (!passed) {
result.reference_check = hytlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
}
else {
result.reference_check = hytlass::Status::kSuccess;
std::cout << "Passed.\n";
}
}
else {
result.reference_check = hytlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "09_tensor_conv_workspace_conv2dwgrad_"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
<< ".dat";
std::ofstream output_workspace(ss.str());
output_workspace
<< "Input = \n" << tensor_a.host_view() << "\n\n"
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference = \n" << tensor_ref_c.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm_op();
HYTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.benchmark) {
// Benchmark several layers
int batch_sizes[] = {1, 32, 64, 128, 256, 512};
struct Benchmark {
int h, w, c, k, r, s;
} layers[] = {
{56, 56, 64, 256, 1, 1},
{56, 56, 64, 64, 1, 1},
{56, 56, 64, 64, 3, 3},
{56, 56, 256, 64, 1, 1},
{56, 56, 256, 512, 1, 1},
{56, 56, 256, 128, 1, 1},
{28, 28, 128, 128, 3, 3},
{28, 28, 128, 512, 1, 1},
{28, 28, 512, 128, 1, 1},
{28, 28, 512, 1024, 1, 1},
{28, 28, 512, 256, 1, 1},
{14, 14, 256, 256, 3, 3},
{14, 14, 256, 1024, 1, 1},
{14, 14, 1024, 256, 1, 1},
{14, 14, 1024, 2048, 1, 1},
{14, 14, 1024, 512, 1, 1},
{7, 7, 512, 512, 3, 3},
};
Result::print_header(std::cout, options) << std::endl;
int idx = 1;
for (auto const &layer : layers) {
for (auto N : batch_sizes) {
options.update({N, layer.h, layer.w, layer.c}, {layer.k, layer.r, layer.s, layer.c});
Result result = profile_convolution(options);
result.print(std::cout, idx, options) << std::endl;
}
++idx;
}
}
else {
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gfx928_tensorop_conv2dwgrad_split_k
gfx928_tensorop_conv2dwgrad_split_k.cu
)
/***************************************************************************************************
* Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
This example shows how to compute the conv2d gradient with respect to the weights (wgrad). In wgrad, the K dimension of
the implicit GEMM, corresponding to the sequential reduction loop, is very large (N * P * Q). Split-k with parallel
reduction is highly effective for such cases. Given the split_k_slices parameter, it partitions the K loop into
split_k_slices chunks and computes partial reductions in parallel across different blocks. After that,
a parallel reduction kernel is launched to accumulate the partial reductions.
In practice, wgrad requires fp32 accumulation to avoid overflow. When the input is fp16, some care is needed
to correctly instantiate the GEMM template.
*/
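// To make the decomposition concrete: for the usage example shown later (n=32,
// h=224, w=224, r=3, s=3, stride 1, pad 1), P = Q = 224, so the GEMM-K extent
// is N * P * Q = 32 * 224 * 224 = 1,605,632. With --split-k-slices=8, each
// slice reduces roughly 200,704 of those iterations into its own fp32 partial
// tensor, and the ReduceSplitK kernel then sums the 8 partials.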
#include <iostream>
#include <fstream>
#include <sstream>
#include "hytlass/hytlass.h"
#include "hytlass/gemm/device/gemm.h"
#include "hytlass/conv/kernel/default_conv2d_wgrad.h"
#include "hytlass/conv/device/implicit_gemm_convolution.h"
#include "hytlass/util/command_line.h"
#include "hytlass/util/host_tensor.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/util/reference/device/gemm.h"
#include "hytlass/util/reference/host/tensor_compare.h"
#include "hytlass/util/reference/host/tensor_copy.h"
#include "hytlass/util/reference/host/tensor_fill.h"
#include "hytlass/util/reference/device/convolution.h"
#include "hytlass/util/tensor_view_io.h"
#include "hytlass/reduction/device/reduce_split_k.h"
#include "hytlass/reduction/thread/reduction_operators.h"
#include "helper.h"
// The code section below describes the data types of the input and output tensors and of the
// computation between elements.
// In wgrad, fp32 accumulation is necessary in practice.
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
using ElementInputA = hytlass::half_t; // Data type of elements in input tensor
using ElementInputB = hytlass::half_t; // Data type of elements in input tensor
using ElementOutput = hytlass::half_t; // Data type of elements in output tensor
using ElementC = ElementOutput;
using ElementCompute = ElementComputeEpilogue;
using LayoutInputA = hytlass::layout::TensorNHWC;
using LayoutInputB = hytlass::layout::TensorNHWC;
using LayoutOutput = hytlass::layout::TensorNHWC;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = hytlass::arch::OpClassTensorOp;
// This code section describes the HIP Gfx architecture number
using SmArch = hytlass::arch::Gfx928;
// This code section describes the tile size a thread block will compute
using ThreadblockShape = hytlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
// This code section describes tile size a warp will compute
using WarpShape = hytlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
// This code section describes the size of MMA op
using InstructionShape = hytlass::gemm::GemmShape<16, 16, 16>; // TensorCore instruction shape
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = hytlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Number of pipelines you want to use
constexpr int NumStages = 2;
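// With two stages, shared-memory tiles are double-buffered: the mainloop loads
// tile k+1 while math instructions consume tile k.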
// This code section describes whether the selected iterator algorithm is Analytic or Optimized
static hytlass::conv::IteratorAlgorithm const IteratorAlgorithm = hytlass::conv::IteratorAlgorithm::kOptimized;
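// (kAnalytic computes tensor coordinates on the fly each iteration; kOptimized
// precomputes pointer deltas on the host, which is generally faster for the
// layouts it supports.)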
// We need two epilogue functors - one for GEMM and another for the final reduction.
// The epilogue for GEMM is not used, but needed to instantiate the HYTLASS kernel template.
// Note that, when the input is fp16 and accumulation is fp32, the output of GEMM needs to be fp32,
// the final reduction is done in fp32, and the reduction epilogue converts fp32 outputs to fp16.
// Therefore, the output type of the GEMM epilogue is ElementCompute, not ElementOutput.
// This code section describes the epilogue part of the kernel; we use the default values.
using EpilogueOpGEMM = hytlass::epilogue::thread::LinearCombination<
ElementCompute, // Data type of output matrix.
128 / hytlass::sizeof_bits<ElementCompute>::value, // The number of elements per vectorized
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
// The epilogue functor for reduction. This is the one that is actually used.
using EpilogueOpReduction = hytlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
128 / hytlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
using Conv2dWgradKernel = typename hytlass::conv::kernel::DefaultConv2dWgrad<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementAccumulator, LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOpGEMM,
SwizzleThreadBlock,
NumStages,
hytlass::arch::OpMultiplyAdd,
IteratorAlgorithm
>::Kernel;
using ImplicitGemm = hytlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
using EpilogueOutputOp = EpilogueOpReduction;
/// Reduction kernel
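// ReduceAdd performs an elementwise sum over the split-k partial tensors; its
// fragment width matches the reduction epilogue's vector length
// (EpilogueOutputOp::kCount) so that accesses stay vectorized.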
using ReductionOp = hytlass::reduction::thread::ReduceAdd<
ElementAccumulator,
typename EpilogueOutputOp::ElementAccumulator,
EpilogueOutputOp::kCount
>;
using ReductionKernel = hytlass::reduction::kernel::ReduceSplitK<
hytlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
EpilogueOutputOp,
ReductionOp
>;
using ReductionDevice = hytlass::reduction::device::ReduceSplitK<ReductionKernel>;
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
hytlass::Tensor4DCoord input_size;
hytlass::Tensor4DCoord filter_size;
hytlass::Tensor4DCoord padding;
hytlass::MatrixCoord conv_stride;
hytlass::MatrixCoord dilation;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
int split_k_slices;
bool benchmark;
std::string tag;
Options():
help(false),
input_size(1, 32, 32, 32),
filter_size(32, 3, 3, 32),
padding(1, 1, 1, 1),
conv_stride(1, 1),
dilation(1, 1),
reference_check(true),
measure_performance(true),
iterations(20),
save_workspace(false),
alpha(1),
beta(0),
split_k_slices(8),
benchmark(false)
{}
// Verify the problem size is compatible with the HYTLASS Convolution implementation.
bool valid() {
//
// HYTLASS attempts to load 128b vectors of hytlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) || (padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
hytlass::Tensor4DCoord input_size,
hytlass::Tensor4DCoord filter_size,
hytlass::MatrixCoord stride) {
this->input_size = input_size;
this->filter_size = filter_size;
conv_stride = stride;
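// "Same" padding: with stride 1, the output H/W match the input H/W for odd filter sizes.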
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
hytlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
if (cmd.check_cmd_line_flag("benchmark")) {
benchmark = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.c() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("split-k-slices", split_k_slices);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
if (filter_size.h() == 3 && filter_size.w() == 3) {
padding = {1, 1, 1, 1};
}
else {
filter_size.h() = 1;
filter_size.w() = 1;
padding = {0, 0, 0, 0};
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "10_hytlass_tensorop_wgrad_split_k example\n\n"
<< " This example shows how to compute conv2d gradient with respect to weight (wgrad).\n"
<< " In wgrad, the K dimension of impligit GEMM, corresponding to the sequential reduction loop, is very large (N * P * Q).\n"
<< " Split-k with parallel reduction is highly effective for such cases.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --split-k-slices=<int> Split-k factor \n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/10_hytlass_tensorop_wgrad_split_k/gfx928_tensorop_conv2dwgrad_split_k --n=32 --h=224 --w=224 --c=128 --k=256 --r=3 --s=3 --split-k-slices=8\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
hytlass::Tensor4DCoord output_size() const {
return hytlass::Tensor4DCoord(input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
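// For the default problem (input 1x32x32x32, filter 32x3x3x32, stride 1,
// pad 1), NPQK = 1*32*32*32 = 32,768 and CRS = 32*3*3 = 288, giving
// fmas = 9,437,184, i.e. roughly 1.9e7 flops per pass.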
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
hytlass::Status status;
hytlass::Status reference_check;
hipError_t error;
Result():
runtime_ms(0),
gflops(0),
status(hytlass::Status::kSuccess),
reference_check(hytlass::Status::kInvalid),
error(hipSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< options.conv_stride.row() << ","
<< options.conv_stride.column() << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
Result profile_convolution(Options const &options) {
Result result;
//
// Allocate host-device tensors using the HYTLASS Utilities.
//
// Inputs are the output gradient and the original activation.
hytlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.output_size());
hytlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.input_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.filter_size);
hytlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.filter_size);
//
// Initialize tensors
//
// Fill tensor A on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(7),
ElementInputA(-8),
0);
// Fill tensor B on host with uniform-distribution random data
hytlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(7),
ElementInputB(-8),
0);
// Fill tensor C, D on host with zeros
hytlass::reference::host::TensorFill(tensor_c.host_view());
hytlass::reference::host::TensorFill(tensor_d.host_view());
// Fill tensor D for reference on host with zeros
hytlass::reference::host::TensorFill(tensor_ref_d.host_view());
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
//
// Define arguments for HYTLASS Convolution
//
hytlass::conv::Mode mode = hytlass::conv::Mode::kCrossCorrelation;
// Partition the GEMM K loop into split_k_slices chunks
int split_k_slices = options.split_k_slices;
// Construct Conv2dProblemSize with user defined output size
// Do not forget to pass the last argument.
hytlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
mode,
split_k_slices
);
using hytlass::layout::TensorNHWC;
hytlass::conv::SplitKMode const split_k_mode = hytlass::conv::SplitKMode::kParallel;
// Since the epilogue is not computed after GEMM, there is no need to pass the C tensor and
// alpha and beta can be set to 1 and 0 respectively.
// Moreover, since the output will be written to the workspace, there is no need to pass
// the D tensor as well.
// Do not forget to pass the last argument.
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
{nullptr, TensorNHWC()},
{nullptr, TensorNHWC()},
{ElementCompute(1), ElementCompute(0)},
split_k_mode
};
//
// Initialize HYTLASS Convolution
//
ImplicitGemm implicit_gemm;
size_t workspace_size = implicit_gemm.get_workspace_size(arguments);
// Split-K requires non-zero workspace size. The workspace size grows linearly with split_k_slices.
std::cout << "split-k-slices: " << split_k_slices << std::endl;
std::cout << "workspace size: " << workspace_size << std::endl;
// Allocate workspace memory
hytlass::device_memory::allocation<uint8_t> workspace(workspace_size);
result.status = implicit_gemm.can_implement(arguments);
HYTLASS_CHECK(result.status);
// After the workspace is allocated, we point the GEMM destination pointer to the workspace.
TensorNHWC layout_D{TensorNHWC::packed(options.filter_size)};
arguments.ref_D.reset(reinterpret_cast<ElementCompute*>(workspace.get()), layout_D);
result.status = implicit_gemm.initialize(arguments, workspace.get());
HYTLASS_CHECK(result.status);
//
// Launch initialized HYTLASS kernel
//
result.status = implicit_gemm();
HYTLASS_CHECK(result.status);
if (split_k_mode == hytlass::conv::SplitKMode::kParallel) {
// Do reduction
ReductionDevice reduction_op;
auto& status = result.status;
static hytlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator;
typename ReductionDevice::Arguments reduction_args(
hytlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), // extent (MxN) of the reduced GEMM output
problem_size.split_k_slices, // number of partial tensors to sum
hytlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), // element stride between consecutive partials in the workspace
// Reduction input
{
reinterpret_cast<ElementAccumulator*> (workspace.get()),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
// Destination
{
tensor_d.device_data(),
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
// Source
{
tensor_c.device_data(),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
{options.alpha, options.beta}
);
status = reduction_op.initialize(reduction_args, nullptr);
HYTLASS_CHECK(status);
status = reduction_op();
HYTLASS_CHECK(status);
}
//
// Optional reference check
//
if (options.reference_check) {
std::cout << "Verification on device...\n";
// Compute with reference implementation
hytlass::reference::device::Conv2dWgrad<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementAccumulator,
hytlass::NumericConverter<ElementOutput, ElementComputeEpilogue>>(
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_ref_d.device_ref(),
options.alpha,
options.beta);
// Check if output from HYTLASS kernel and reference kernel are equal or not
tensor_c.sync_host();
tensor_d.sync_host();
tensor_ref_d.sync_host();
const ElementOutput non_zero_floor(1e-6f);
ElementOutput eps(0.05);
bool passed = hytlass::reference::host::TensorRelativelyEquals(tensor_d.host_view(), tensor_ref_d.host_view(), eps, non_zero_floor);
if (!passed) {
result.reference_check = hytlass::Status::kErrorInternal;
std::cout << "ERROR - results miscompared.\n";
}
else {
result.reference_check = hytlass::Status::kSuccess;
std::cout << "Passed.\n";
}
}
else {
result.reference_check = hytlass::Status::kInvalid;
}
if (options.save_workspace) {
std::stringstream ss;
ss << "10_hytlass_tensorop_wgrad_split_k"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
<< ".dat";
std::ofstream output_workspace(ss.str());
output_workspace
<< "Input = \n" << tensor_a.host_view() << "\n\n"
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
if (options.reference_check) {
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
}
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
}
//
// Performance measurement
//
if (options.measure_performance) {
hipEvent_t events[2];
for (auto & event : events) {
result.error = hipEventCreate(&event);
if (result.error != hipSuccess) {
std::cerr << "hipEventCreate() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
}
// Record an event at the start of a series of convolution operations.
result.error = hipEventRecord(events[0]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Launch a sequence of implicit GEMM operations on the device
for (int iteration = 0; iteration < options.iterations; ++iteration) {
result.status = implicit_gemm();
HYTLASS_CHECK(result.status);
}
// Record an event when the convolutions have been launched.
result.error = hipEventRecord(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventRecord() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Wait for work on the device to complete.
result.error = hipEventSynchronize(events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventSynchronize() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Measure elapsed runtime
float runtime_ms = 0;
result.error = hipEventElapsedTime(&runtime_ms, events[0], events[1]);
if (result.error != hipSuccess) {
std::cerr << "hipEventElapsed() failed: " << hipGetErrorString(result.error) << std::endl;
return result;
}
// Print average runtime and GFLOPs.
result.runtime_ms = double(runtime_ms) / double(options.iterations);
result.gflops = options.gflops(result.runtime_ms / 1000.0);
// Cleanup
for (auto event : events) {
(void)hipEventDestroy(event);
}
}
return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.benchmark) {
// Benchmark several layers
int batch_sizes[] = {34, 408};
struct Benchmark {
int h, w, c, k, r, s, stride_h, stride_w;
} layers[] = {
{56, 56, 64, 256, 1, 1, 1, 1},
{56, 56, 64, 64, 1, 1, 1, 1},
{56, 56, 64, 64, 3, 3, 1, 1},
{56, 56, 256, 64, 1, 1, 1, 1},
{56, 56, 256, 512, 1, 1, 2, 2},
{56, 56, 256, 128, 1, 1, 1, 1},
{56, 56, 128, 128, 3, 3, 2, 2},
{28, 28, 128, 512, 1, 1, 1, 1},
{28, 28, 512, 128, 1, 1, 1, 1},
{28, 28, 128, 128, 3, 3, 1, 1},
{28, 28, 512, 1024, 1, 1, 2, 2},
{28, 28, 512, 256, 1, 1, 1, 1},
{28, 28, 256, 256, 3, 3, 2, 2},
{14, 14, 256, 1024, 1, 1, 1, 1},
{14, 14, 1024, 256, 1, 1, 1, 1},
{14, 14, 256, 256, 3, 3, 1, 1},
{14, 14, 1024, 2048, 1, 1, 2, 2},
{14, 14, 1024, 512, 1, 1, 1, 1},
{14, 14, 512, 512, 3, 3, 2, 2},
{ 7, 7, 512, 2048, 1, 1, 1, 1},
{ 7, 7, 2048, 512, 1, 1, 1, 1},
{ 7, 7, 512, 512, 3, 3, 1, 1},
};
Result::print_header(std::cout, options) << std::endl;
int idx = 1;
for (auto const &layer : layers) {
for (auto N : batch_sizes) {
options.update({N, layer.h, layer.w, layer.c},
{layer.k, layer.r, layer.s, layer.c},
{layer.stride_h, layer.stride_w});
Result result = profile_convolution(options);
result.print(std::cout, idx, options) << std::endl;
}
++idx;
}
}
else {
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
hytlass_example_add_executable(
gfx928_tensorop_group_conv2dfprop
gfx928_tensorop_group_conv2dfprop.cu
)