Unverified Commit 60b13256 authored by Babak Hejazi's avatar Babak Hejazi Committed by GitHub
Browse files

Benchmark - Support autotuning in cublaslt gemm (#706)

**Description**
Enable autotuning as an opt-in mode when benchmarking cublasLt via
`cublaslt_gemm`

The implementation is based on
https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemmSimpleAutoTuning/sample_cublasLt_LtSgemmSimpleAutoTuning.cu

The behavior of the original benchmark command remains unchanged, e.g.:
- `cublaslt_gemm -m 2048 -n 12288 -k 1536 -w10000 -i 1000 -t fp8e4m3`

The new opt-in options are `-a` (for autotune) and `-I` (for autotune
iterations, default is 50, same as the default for `-i`) and `-W` (for
autotune warmups, default=20, same as the default for `-w`), e.g.:
- `cublaslt_gemm -m 2048 -n 12288 -k 1536 -w 10000 -i 1000 -t fp8e4m3
-a`
- `cublaslt_gemm -m 2048 -n 12288 -k 1536 -w 10000 -i 1000 -t fp8e4m3 -a
-I 10 -W 10`

**Note:** This PR also changes the default `gemm_compute_type` for BF16
and FP16 to `CUBLAS_COMPUTE_32F`.

**Further observations:** 
1. The support matrix of the `cublaslt_gemm` could be further extended
in the future to support non-FP16 output as well for FP8 inputs.
2. Currently, the input matrices are initialized with values of 1.0 and
2.0 which makes them less demanding in terms of power. Another future
extension could be to enable another fill mode for, say, uniform random
numbers between -1 and 1.
3. cuBLAS workspace recommendations are listed under
https://docs.nvidia.com/cuda/cublas/#cublassetworkspace



Update (June 10, 2025): verified using higher level test driver with
these commands:

1. inline:
```
python3 -c "                                                                            
from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger

parameters = (
    '--num_warmup 10 --num_steps 50 '
    '--shapes 512,512,512 1024,1024,1024 --in_types fp16 fp32 '
    '--enable_autotune --num_warmup_autotune 20 --num_steps_autotune 50'
)
context = BenchmarkRegistry.create_benchmark_context(
    'cublaslt-gemm', platform=Platform.CUDA, parameters=parameters
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
logger.info('Result: {}'.format(benchmark.result))
"
```

2. newly added script: 
`python3 examples/benchmarks/cublaslt_function.py`

---------
Co-authored-by: default avatarBabak Hejazi <babakh@nvidia.com>
parent 0b8d1fd4
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Micro benchmark example for cuBLASLt GEMM performance benchmark.
Commands to run:
python3 examples/benchmarks/cublaslt_function.py
"""
from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger
if __name__ == '__main__':
    # Each run is (console banner, benchmark parameters, label used in the result log line).
    # The three runs exercise: plain heuristic mode, autotune mode, and FP8 + autotune.
    benchmark_runs = [
        (
            'Running cuBLASLt benchmark without autotune...',
            '--num_warmup 10 --num_steps 50 --shapes 512,512,512 --in_types fp16 fp32',
            'benchmark',
        ),
        (
            '\nRunning cuBLASLt benchmark with autotune enabled...',
            '--num_warmup 10 --num_steps 50 '
            '--shapes 512,512,512 1024,1024,1024 --in_types fp16 fp32 '
            '--enable_autotune --num_warmup_autotune 20 --num_steps_autotune 50',
            'benchmark with autotune',
        ),
        (
            '\nRunning cuBLASLt benchmark with FP8 and autotune...',
            '--num_warmup 5 --num_steps 20 '
            '--shapes 512,512,512 --in_types fp8e4m3 fp8e5m2 '
            '--enable_autotune --num_warmup_autotune 10 --num_steps_autotune 30',
            'FP8 benchmark with autotune',
        ),
    ]
    for banner, parameters, label in benchmark_runs:
        print(banner)
        context = BenchmarkRegistry.create_benchmark_context(
            'cublaslt-gemm', platform=Platform.CUDA, parameters=parameters
        )
        # launch_benchmark returns None when the benchmark cannot be created/launched,
        # so guard before reading result fields.
        benchmark = BenchmarkRegistry.launch_benchmark(context)
        if benchmark:
            logger.info(
                '{}: {}, return code: {}, result: {}'.format(
                    label, benchmark.name, benchmark.return_code, benchmark.result
                )
            )
...@@ -36,6 +36,26 @@ def add_parser_arguments(self): ...@@ -36,6 +36,26 @@ def add_parser_arguments(self):
required=False, required=False,
help='List of input data types, support {}.'.format(' '.join(self._in_types)), help='List of input data types, support {}.'.format(' '.join(self._in_types)),
) )
self._parser.add_argument(
'--enable_autotune',
action='store_true',
required=False,
help='Enable exhaustive autotune mode to find best algorithm.',
)
self._parser.add_argument(
'--num_warmup_autotune',
type=int,
default=20,
required=False,
help='Number of warm up steps for autotune.',
)
self._parser.add_argument(
'--num_steps_autotune',
type=int,
default=50,
required=False,
help='Number of steps to measure for autotune.',
)
def _preprocess(self): def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking. """Preprocess/preparation operations before the benchmarking.
...@@ -50,9 +70,16 @@ def _preprocess(self): ...@@ -50,9 +70,16 @@ def _preprocess(self):
self._commands = [] self._commands = []
for _m, _n, _k, _b, _in_type in self._shapes_to_run: for _m, _n, _k, _b, _in_type in self._shapes_to_run:
# pull out the autotune args onto their own short f-string
autotune_args = (
f' -a -W {self._args.num_warmup_autotune}'
f' -I {self._args.num_steps_autotune}'
) if self._args.enable_autotune else ''
self._commands.append( self._commands.append(
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} ' f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}' f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
f'{(" " + autotune_args) if autotune_args else ""}'
) )
return True return True
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT License. # Licensed under the MIT License.
cmake_minimum_required(VERSION 3.18) cmake_minimum_required(VERSION 3.18)
set(CMAKE_CUDA_RUNTIME_LIBRARY SHARED)
project(cublaslt_gemm LANGUAGES CXX) project(cublaslt_gemm LANGUAGES CXX)
find_package(CUDAToolkit QUIET) find_package(CUDAToolkit QUIET)
......
...@@ -25,16 +25,24 @@ struct Args { ...@@ -25,16 +25,24 @@ struct Args {
int batch = 0; int batch = 0;
int warmup = 20; int warmup = 20;
int iter = 50; int iter = 50;
// Default warmup iterations for autotune
int warmup_autotune = 20;
// Default repeat iterations for autotune
int iter_autotune = 50;
std::string in_type = "fp8e4m3"; std::string in_type = "fp8e4m3";
bool autotune = false;
}; };
void process_args(int argc, char **argv, Args *args) { void process_args(int argc, char **argv, Args *args) {
const char *const short_opts = "m:n:k:b:w:i:t:"; const char *const short_opts = "m:n:k:b:w:i:t:aI:W:";
const option long_opts[] = { const option long_opts[] = {
{"batch", required_argument, nullptr, 'b'}, {"batch", required_argument, nullptr, 'b'},
{"warmup", required_argument, nullptr, 'w'}, {"warmup", required_argument, nullptr, 'w'},
{"iter", required_argument, nullptr, 'i'}, {"iter", required_argument, nullptr, 'i'},
{"in_type", required_argument, nullptr, 't'}, {"in_type", required_argument, nullptr, 't'},
{"autotune", no_argument, nullptr, 'a'},
{"iter-autotune", required_argument, nullptr, 'I'},
{"warmup-autotune", required_argument, nullptr, 'W'},
}; };
int opt = 0; int opt = 0;
...@@ -61,6 +69,15 @@ void process_args(int argc, char **argv, Args *args) { ...@@ -61,6 +69,15 @@ void process_args(int argc, char **argv, Args *args) {
case 't': case 't':
args->in_type = std::string(optarg); args->in_type = std::string(optarg);
break; break;
case 'a':
args->autotune = true;
break;
case 'I':
args->iter_autotune = std::stoi(optarg);
break;
case 'W':
args->warmup_autotune = std::stoi(optarg);
break;
} }
} }
} }
...@@ -91,7 +108,8 @@ template <typename T> cudaDataType_t get_datatype() { ...@@ -91,7 +108,8 @@ template <typename T> cudaDataType_t get_datatype() {
} }
template <typename Ta, typename Tb, typename Tout> template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter) { float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter, bool autotune,
int iter_autotune, int warmup_autotune) {
// init matrix // init matrix
Ta *matrix_a = nullptr; Ta *matrix_a = nullptr;
Tb *matrix_b = nullptr; Tb *matrix_b = nullptr;
...@@ -112,7 +130,16 @@ float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, i ...@@ -112,7 +130,16 @@ float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, i
CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT); CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT);
void *workspace = nullptr; void *workspace = nullptr;
size_t workspace_size = gemm->GetAlgorithm(1, 2 * m * n); size_t workspace_size;
if (autotune) {
workspace_size = gemm->GetAlgorithmExhaustive(
8, 2 * m * n, 1.0f, 0.0f, reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), iter_autotune, warmup_autotune);
} else {
workspace_size = gemm->GetAlgorithm(1, 2 * m * n);
}
cudaMalloc(&workspace, workspace_size); cudaMalloc(&workspace, workspace_size);
// timer // timer
...@@ -142,8 +169,9 @@ float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, i ...@@ -142,8 +169,9 @@ float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, i
return (time * 1e3 / iter); return (time * 1e3 / iter);
} }
template <typename Ta, typename Tb = Ta, typename Tout = Ta> void run(Args *args) { template <typename Ta, typename Tb = Ta, typename Tout = Ta> void run(const Args *args) {
float time_us = timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter); float time_us = timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter,
args->autotune, args->iter_autotune, args->warmup_autotune);
// m n k batch time_us tflops // m n k batch time_us tflops
printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us, printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us,
float(args->m) * float(args->n) * float(2 * args->k - 1) / 1e6 / time_us * std::max(args->batch, 1)); float(args->m) * float(args->n) * float(2 * args->k - 1) / 1e6 / time_us * std::max(args->batch, 1));
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Licensed under the MIT License. // Licensed under the MIT License.
#include "cublaslt_utils.h" #include "cublaslt_utils.h"
#include <algorithm> // for std::sort
#include <cassert> // for assert
void cublasLtGemm::Init() { void cublasLtGemm::Init() {
cublasLtHandle_t handle; cublasLtHandle_t handle;
...@@ -20,6 +22,11 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -20,6 +22,11 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
void *a_scale_inverse, /* only need to be set for fp8 */ void *a_scale_inverse, /* only need to be set for fp8 */
void *b_scale_inverse /* only need to be set for fp8 */ void *b_scale_inverse /* only need to be set for fp8 */
) { ) {
// Store dimensions
m_ = m;
n_ = n;
k_ = k;
cublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr; cublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr;
// force c_type // force c_type
cudaDataType_t c_type = d_type; cudaDataType_t c_type = d_type;
...@@ -33,7 +40,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -33,7 +40,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
// strided batch gemm // strided batch gemm
if (batch > 0) { if (batch > 0) {
int64_t stridea = m * k, strideb = k * n, stridec = m * n, strided = m * n; int64_t stridea = static_cast<int64_t>(m) * k, strideb = static_cast<int64_t>(k) * n,
stridec = static_cast<int64_t>(m) * n, strided = static_cast<int64_t>(m) * n;
CUBLAS_CHECK( CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea,
...@@ -56,23 +64,35 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -56,23 +64,35 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
c_desc_.reset(c_desc); c_desc_.reset(c_desc);
d_desc_.reset(d_desc); d_desc_.reset(d_desc);
// default to tf32 except for e5m2 inputs where the config is not supported // Set compute type and scale type based on input types
cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; cublasComputeType_t gemm_compute_type;
if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) cudaDataType_t scale_type;
if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
gemm_compute_type = CUBLAS_COMPUTE_32F; gemm_compute_type = CUBLAS_COMPUTE_32F;
if (a_type == CUDA_R_64F || b_type == CUDA_R_64F) scale_type = CUDA_R_32F;
} else if (a_type == CUDA_R_16F || b_type == CUDA_R_16F || a_type == CUDA_R_16BF || b_type == CUDA_R_16BF) {
gemm_compute_type = CUBLAS_COMPUTE_32F;
scale_type = CUDA_R_32F;
} else if (a_type == CUDA_R_64F || b_type == CUDA_R_64F) {
gemm_compute_type = CUBLAS_COMPUTE_64F; gemm_compute_type = CUBLAS_COMPUTE_64F;
if (a_type == CUDA_R_8I) scale_type = CUDA_R_64F;
} else if (a_type == CUDA_R_8I) {
gemm_compute_type = CUBLAS_COMPUTE_32I; gemm_compute_type = CUBLAS_COMPUTE_32I;
scale_type = CUDA_R_32I;
} else {
gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
scale_type = CUDA_R_32F;
}
cublasLtMatmulDesc_t op_desc = nullptr; cublasLtMatmulDesc_t op_desc = nullptr;
CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F)); CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, scale_type));
op_desc_.reset(op_desc); op_desc_.reset(op_desc);
if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) { if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
// enable FP8 fast accumulation (fastAccuMode = 1)
int8_t fastAccuMode = 1; int8_t fastAccuMode = 1;
cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode)); CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode,
sizeof(fastAccuMode)));
} }
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
...@@ -93,10 +113,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -93,10 +113,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) { size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) {
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size))); &max_workspace_size, sizeof(max_workspace_size)));
int found_algorithm_count = 0; int found_algorithm_count = 0;
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count); std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
// Though we query all of possible algorithm, we will use the first later
CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(), CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count, c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
results.data(), &found_algorithm_count)); results.data(), &found_algorithm_count));
...@@ -109,15 +127,120 @@ size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_ ...@@ -109,15 +127,120 @@ size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_
return heuristic_results_.front().workspaceSize; return heuristic_results_.front().workspaceSize;
} }
size_t cublasLtGemm::GetAlgorithmExhaustive(int max_algorithm_count, size_t max_workspace_size, float alpha, float beta,
void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d,
int repeat_iterations, int warmup_iterations) {
// Set workspace size in preference
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size)));
// Get heuristic algorithms
int found_algorithm_count = 0;
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
results.data(), &found_algorithm_count));
if (found_algorithm_count == 0) {
throw std::runtime_error("Unable to find any suitable algorithms");
}
results.resize(found_algorithm_count);
heuristic_results_ = std::move(results);
// Create stream and events for timing
cudaStream_t stream;
cudaEvent_t startEvent, stopEvent;
cudaStreamCreate(&stream);
cudaEventCreate(&startEvent);
cudaEventCreate(&stopEvent);
// Test each algorithm multiple times to find the best one
std::vector<float> algoTimes(repeat_iterations);
// Allocate workspace
void *workspace = nullptr;
cudaMalloc(&workspace, max_workspace_size);
// Test each algorithm
algo_metrics_.clear();
algo_metrics_.reserve(found_algorithm_count);
for (int algoIdx = 0; algoIdx < found_algorithm_count; algoIdx++) {
// Skip algorithms that require more workspace than available
if (heuristic_results_[algoIdx].workspaceSize > max_workspace_size) {
continue;
}
// warmup
for (int warmupIdx = 0; warmupIdx < warmup_iterations; warmupIdx++) {
cublasStatus_t status =
cublasLtMatmul(handle_.get(), op_desc_.get(), &alpha, matrix_a, a_desc_.get(), matrix_b, b_desc_.get(),
&beta, matrix_c, c_desc_.get(), matrix_d, d_desc_.get(),
&heuristic_results_[algoIdx].algo, workspace, max_workspace_size, stream);
}
// Test each algorithm multiple times
cudaEventRecord(startEvent, stream);
for (int checkIdx = 0; checkIdx < repeat_iterations; checkIdx++) {
cublasStatus_t status =
cublasLtMatmul(handle_.get(), op_desc_.get(), &alpha, matrix_a, a_desc_.get(), matrix_b, b_desc_.get(),
&beta, matrix_c, c_desc_.get(), matrix_d, d_desc_.get(),
&heuristic_results_[algoIdx].algo, workspace, max_workspace_size, stream);
// Skip if algorithm fails
if (status != CUBLAS_STATUS_SUCCESS) {
algoTimes[checkIdx] = std::numeric_limits<float>::max();
continue;
}
}
cudaEventRecord(stopEvent, stream);
cudaEventSynchronize(stopEvent);
float time = 0;
cudaEventElapsedTime(&time, startEvent, stopEvent);
algoTimes[algoIdx] = time / repeat_iterations;
float meanTime = algoTimes[algoIdx];
float flops = 2.0f * m_ * n_ * k_ / (meanTime * 1e-3f);
// Store metrics
AlgorithmMetrics metrics;
metrics.algo = heuristic_results_[algoIdx].algo;
metrics.workspace_size = heuristic_results_[algoIdx].workspaceSize;
metrics.time = meanTime;
metrics.flops = flops;
algo_metrics_.push_back(metrics);
}
std::sort(algo_metrics_.begin(), algo_metrics_.end(),
[](const AlgorithmMetrics &a, const AlgorithmMetrics &b) { return a.time < b.time; });
if (!algo_metrics_.empty())
heuristic_results_[0].algo = algo_metrics_.front().algo;
// Clean up resources
cudaFree(workspace);
cudaEventDestroy(startEvent);
cudaEventDestroy(stopEvent);
cudaStreamDestroy(stream);
if (!algo_metrics_.empty()) {
return algo_metrics_.front().workspace_size;
}
throw std::runtime_error("No valid algorithms found during autotune");
}
void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta, void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
void *workspace, size_t workspace_size, cudaStream_t stream) { void *workspace, size_t workspace_size, cudaStream_t stream) {
CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */ CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
matrix_a, /* A */ matrix_a, /* A */
a_desc_.get(), matrix_b, /* B */ a_desc_.get(), matrix_b, /* B */
b_desc_.get(), static_cast<const void *>(&beta), /* beta */ b_desc_.get(), static_cast<const void *>(&beta), /* beta */
matrix_c, /* C */ matrix_c, /* C */
c_desc_.get(), matrix_d, /* D */ c_desc_.get(), matrix_d, /* D */
d_desc_.get(), &heuristic_results_.front().algo, /* algo */ d_desc_.get(), &heuristic_results_.front().algo, workspace, /* workspace */
workspace, /* workspace */
workspace_size, stream)); /* stream */ workspace_size, stream)); /* stream */
} }
...@@ -51,9 +51,21 @@ class cublasLtGemm { ...@@ -51,9 +51,21 @@ class cublasLtGemm {
size_t GetAlgorithm(int max_algorithm_count, size_t max_workspace_size); size_t GetAlgorithm(int max_algorithm_count, size_t max_workspace_size);
size_t GetAlgorithmExhaustive(int max_algorithm_count, size_t max_workspace_size, float alpha, float beta,
void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d,
int repeat_iterations = 100, int warmup_iterations = 100);
void Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta, void Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
void *workspace, size_t workspace_size, cudaStream_t stream); void *workspace, size_t workspace_size, cudaStream_t stream);
// Type to store algorithm performance metrics
struct AlgorithmMetrics {
cublasLtMatmulAlgo_t algo;
size_t workspace_size;
float time;
float flops;
};
private: private:
UniqueHandle handle_; UniqueHandle handle_;
UniqueOpDesc op_desc_; UniqueOpDesc op_desc_;
...@@ -63,4 +75,10 @@ class cublasLtGemm { ...@@ -63,4 +75,10 @@ class cublasLtGemm {
UniqueLayoutDesc d_desc_; UniqueLayoutDesc d_desc_;
UniqueMatmulPreference preference_; UniqueMatmulPreference preference_;
std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results_; std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results_;
std::vector<AlgorithmMetrics> algo_metrics_;
cublasComputeType_t compute_type_ = CUBLAS_COMPUTE_32F;
cudaDataType_t scale_type_ = CUDA_R_32F;
int m_ = 0;
int n_ = 0;
int k_ = 0;
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment