Unverified Commit 60b13256 authored by Babak Hejazi's avatar Babak Hejazi Committed by GitHub
Browse files

Benchmark - Support autotuning in cublaslt gemm (#706)

**Description**
Enable autotuning as an opt-in mode when benchmarking cublasLt via
`cublaslt_gemm`

The implementation is based on
https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemmSimpleAutoTuning/sample_cublasLt_LtSgemmSimpleAutoTuning.cu

The behavior of original benchmark command remains unchanged, e.g.:
- `cublaslt_gemm -m 2048 -n 12288 -k 1536 -w10000 -i 1000 -t fp8e4m3`

The new opt-in options are `-a` (enable autotune), `-I` (autotune
iterations, default 50 — the same default as `-i`), and `-W` (autotune
warmups, default 20 — the same default as `-w`), e.g.:
- `cublaslt_gemm -m 2048 -n 12288 -k 1536 -w 10000 -i 1000 -t fp8e4m3
-a`
- `cublaslt_gemm -m 2048 -n 12288 -k 1536 -w 10000 -i 1000 -t fp8e4m3 -a
-I 10 -W 10`

**Note:** This PR also changes the default `gemm_compute_type` for BF16
and FP16 to `CUBLAS_COMPUTE_32F`.

**Further observations:** 
1. The support matrix of the `cublaslt_gemm` could be furt...
parent 0b8d1fd4
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Micro benchmark example for cuBLASLt GEMM performance benchmark.
Commands to run:
python3 examples/benchmarks/cublaslt_function.py
"""
from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger
if __name__ == '__main__':

    def _launch(banner, message_format, parameters):
        """Create, launch, and log one cuBLASLt GEMM benchmark run.

        Args:
            banner (str): message printed before the run starts.
            message_format (str): logger format with three placeholders for
                benchmark name, return code, and result.
            parameters (str): command-line style parameter string for the
                'cublaslt-gemm' benchmark.
        """
        print(banner)
        ctx = BenchmarkRegistry.create_benchmark_context(
            'cublaslt-gemm', platform=Platform.CUDA, parameters=parameters
        )
        result = BenchmarkRegistry.launch_benchmark(ctx)
        # launch_benchmark may return a falsy value on failure; only log success.
        if result:
            logger.info(message_format.format(result.name, result.return_code, result.result))

    # Basic usage without autotune.
    _launch(
        'Running cuBLASLt benchmark without autotune...',
        'benchmark: {}, return code: {}, result: {}',
        '--num_warmup 10 --num_steps 50 --shapes 512,512,512 --in_types fp16 fp32',
    )

    # Enhanced usage with autotune enabled (two shapes, fp16 + fp32).
    _launch(
        '\nRunning cuBLASLt benchmark with autotune enabled...',
        'benchmark with autotune: {}, return code: {}, result: {}',
        '--num_warmup 10 --num_steps 50 '
        '--shapes 512,512,512 1024,1024,1024 --in_types fp16 fp32 '
        '--enable_autotune --num_warmup_autotune 20 --num_steps_autotune 50',
    )

    # FP8-specific usage with autotune.
    _launch(
        '\nRunning cuBLASLt benchmark with FP8 and autotune...',
        'FP8 benchmark with autotune: {}, return code: {}, result: {}',
        '--num_warmup 5 --num_steps 20 '
        '--shapes 512,512,512 --in_types fp8e4m3 fp8e5m2 '
        '--enable_autotune --num_warmup_autotune 10 --num_steps_autotune 30',
    )
...@@ -36,6 +36,26 @@ def add_parser_arguments(self): ...@@ -36,6 +36,26 @@ def add_parser_arguments(self):
required=False, required=False,
help='List of input data types, support {}.'.format(' '.join(self._in_types)), help='List of input data types, support {}.'.format(' '.join(self._in_types)),
) )
self._parser.add_argument(
'--enable_autotune',
action='store_true',
required=False,
help='Enable exhaustive autotune mode to find best algorithm.',
)
self._parser.add_argument(
'--num_warmup_autotune',
type=int,
default=20,
required=False,
help='Number of warm up steps for autotune.',
)
self._parser.add_argument(
'--num_steps_autotune',
type=int,
default=50,
required=False,
help='Number of steps to measure for autotune.',
)
def _preprocess(self): def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking. """Preprocess/preparation operations before the benchmarking.
...@@ -50,9 +70,16 @@ def _preprocess(self): ...@@ -50,9 +70,16 @@ def _preprocess(self):
self._commands = [] self._commands = []
for _m, _n, _k, _b, _in_type in self._shapes_to_run: for _m, _n, _k, _b, _in_type in self._shapes_to_run:
# pull out the autotune args onto their own short f-string
autotune_args = (
f' -a -W {self._args.num_warmup_autotune}'
f' -I {self._args.num_steps_autotune}'
) if self._args.enable_autotune else ''
self._commands.append( self._commands.append(
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} ' f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}' f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
f'{(" " + autotune_args) if autotune_args else ""}'
) )
return True return True
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT License. # Licensed under the MIT License.
cmake_minimum_required(VERSION 3.18) cmake_minimum_required(VERSION 3.18)
set(CMAKE_CUDA_RUNTIME_LIBRARY SHARED)
project(cublaslt_gemm LANGUAGES CXX) project(cublaslt_gemm LANGUAGES CXX)
find_package(CUDAToolkit QUIET) find_package(CUDAToolkit QUIET)
......
...@@ -25,16 +25,24 @@ struct Args { ...@@ -25,16 +25,24 @@ struct Args {
int batch = 0; int batch = 0;
int warmup = 20; int warmup = 20;
int iter = 50; int iter = 50;
// Default warmup iterations for autotune
int warmup_autotune = 20;
// Default repeat iterations for autotune
int iter_autotune = 50;
std::string in_type = "fp8e4m3"; std::string in_type = "fp8e4m3";
bool autotune = false;
}; };
void process_args(int argc, char **argv, Args *args) { void process_args(int argc, char **argv, Args *args) {
const char *const short_opts = "m:n:k:b:w:i:t:"; const char *const short_opts = "m:n:k:b:w:i:t:aI:W:";
const option long_opts[] = { const option long_opts[] = {
{"batch", required_argument, nullptr, 'b'}, {"batch", required_argument, nullptr, 'b'},
{"warmup", required_argument, nullptr, 'w'}, {"warmup", required_argument, nullptr, 'w'},
{"iter", required_argument, nullptr, 'i'}, {"iter", required_argument, nullptr, 'i'},
{"in_type", required_argument, nullptr, 't'}, {"in_type", required_argument, nullptr, 't'},
{"autotune", no_argument, nullptr, 'a'},
{"iter-autotune", required_argument, nullptr, 'I'},
{"warmup-autotune", required_argument, nullptr, 'W'},
}; };
int opt = 0; int opt = 0;
...@@ -61,6 +69,15 @@ void process_args(int argc, char **argv, Args *args) { ...@@ -61,6 +69,15 @@ void process_args(int argc, char **argv, Args *args) {
case 't': case 't':
args->in_type = std::string(optarg); args->in_type = std::string(optarg);
break; break;
case 'a':
args->autotune = true;
break;
case 'I':
args->iter_autotune = std::stoi(optarg);
break;
case 'W':
args->warmup_autotune = std::stoi(optarg);
break;
} }
} }
} }
...@@ -91,7 +108,8 @@ template <typename T> cudaDataType_t get_datatype() { ...@@ -91,7 +108,8 @@ template <typename T> cudaDataType_t get_datatype() {
} }
template <typename Ta, typename Tb, typename Tout> template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter) { float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter, bool autotune,
int iter_autotune, int warmup_autotune) {
// init matrix // init matrix
Ta *matrix_a = nullptr; Ta *matrix_a = nullptr;
Tb *matrix_b = nullptr; Tb *matrix_b = nullptr;
...@@ -112,7 +130,16 @@ float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, i ...@@ -112,7 +130,16 @@ float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, i
CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT); CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT);
void *workspace = nullptr; void *workspace = nullptr;
size_t workspace_size = gemm->GetAlgorithm(1, 2 * m * n); size_t workspace_size;
if (autotune) {
workspace_size = gemm->GetAlgorithmExhaustive(
8, 2 * m * n, 1.0f, 0.0f, reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), iter_autotune, warmup_autotune);
} else {
workspace_size = gemm->GetAlgorithm(1, 2 * m * n);
}
cudaMalloc(&workspace, workspace_size); cudaMalloc(&workspace, workspace_size);
// timer // timer
...@@ -142,8 +169,9 @@ float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, i ...@@ -142,8 +169,9 @@ float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, i
return (time * 1e3 / iter); return (time * 1e3 / iter);
} }
template <typename Ta, typename Tb = Ta, typename Tout = Ta> void run(Args *args) { template <typename Ta, typename Tb = Ta, typename Tout = Ta> void run(const Args *args) {
float time_us = timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter); float time_us = timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter,
args->autotune, args->iter_autotune, args->warmup_autotune);
// m n k batch time_us tflops // m n k batch time_us tflops
printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us, printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us,
float(args->m) * float(args->n) * float(2 * args->k - 1) / 1e6 / time_us * std::max(args->batch, 1)); float(args->m) * float(args->n) * float(2 * args->k - 1) / 1e6 / time_us * std::max(args->batch, 1));
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Licensed under the MIT License. // Licensed under the MIT License.
#include "cublaslt_utils.h" #include "cublaslt_utils.h"
#include <algorithm> // for std::sort
#include <cassert> // for assert
void cublasLtGemm::Init() { void cublasLtGemm::Init() {
cublasLtHandle_t handle; cublasLtHandle_t handle;
...@@ -20,6 +22,11 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -20,6 +22,11 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
void *a_scale_inverse, /* only need to be set for fp8 */ void *a_scale_inverse, /* only need to be set for fp8 */
void *b_scale_inverse /* only need to be set for fp8 */ void *b_scale_inverse /* only need to be set for fp8 */
) { ) {
// Store dimensions
m_ = m;
n_ = n;
k_ = k;
cublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr; cublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr;
// force c_type // force c_type
cudaDataType_t c_type = d_type; cudaDataType_t c_type = d_type;
...@@ -33,7 +40,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -33,7 +40,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
// strided batch gemm // strided batch gemm
if (batch > 0) { if (batch > 0) {
int64_t stridea = m * k, strideb = k * n, stridec = m * n, strided = m * n; int64_t stridea = static_cast<int64_t>(m) * k, strideb = static_cast<int64_t>(k) * n,
stridec = static_cast<int64_t>(m) * n, strided = static_cast<int64_t>(m) * n;
CUBLAS_CHECK( CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea,
...@@ -56,23 +64,35 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -56,23 +64,35 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
c_desc_.reset(c_desc); c_desc_.reset(c_desc);
d_desc_.reset(d_desc); d_desc_.reset(d_desc);
// default to tf32 except for e5m2 inputs where the config is not supported // Set compute type and scale type based on input types
cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; cublasComputeType_t gemm_compute_type;
if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) cudaDataType_t scale_type;
if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
gemm_compute_type = CUBLAS_COMPUTE_32F; gemm_compute_type = CUBLAS_COMPUTE_32F;
if (a_type == CUDA_R_64F || b_type == CUDA_R_64F) scale_type = CUDA_R_32F;
} else if (a_type == CUDA_R_16F || b_type == CUDA_R_16F || a_type == CUDA_R_16BF || b_type == CUDA_R_16BF) {
gemm_compute_type = CUBLAS_COMPUTE_32F;
scale_type = CUDA_R_32F;
} else if (a_type == CUDA_R_64F || b_type == CUDA_R_64F) {
gemm_compute_type = CUBLAS_COMPUTE_64F; gemm_compute_type = CUBLAS_COMPUTE_64F;
if (a_type == CUDA_R_8I) scale_type = CUDA_R_64F;
} else if (a_type == CUDA_R_8I) {
gemm_compute_type = CUBLAS_COMPUTE_32I; gemm_compute_type = CUBLAS_COMPUTE_32I;
scale_type = CUDA_R_32I;
} else {
gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
scale_type = CUDA_R_32F;
}
cublasLtMatmulDesc_t op_desc = nullptr; cublasLtMatmulDesc_t op_desc = nullptr;
CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F)); CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, scale_type));
op_desc_.reset(op_desc); op_desc_.reset(op_desc);
if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) { if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
// disable fastAccuMode, set to 0
int8_t fastAccuMode = 1; int8_t fastAccuMode = 1;
cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode)); CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode,
sizeof(fastAccuMode)));
} }
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
...@@ -93,10 +113,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l ...@@ -93,10 +113,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) { size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) {
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size))); &max_workspace_size, sizeof(max_workspace_size)));
int found_algorithm_count = 0; int found_algorithm_count = 0;
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count); std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
// Though we query all of possible algorithm, we will use the first later
CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(), CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count, c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
results.data(), &found_algorithm_count)); results.data(), &found_algorithm_count));
...@@ -109,15 +127,120 @@ size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_ ...@@ -109,15 +127,120 @@ size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_
return heuristic_results_.front().workspaceSize; return heuristic_results_.front().workspaceSize;
} }
size_t cublasLtGemm::GetAlgorithmExhaustive(int max_algorithm_count, size_t max_workspace_size, float alpha, float beta,
void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d,
int repeat_iterations, int warmup_iterations) {
// Set workspace size in preference
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size)));
// Get heuristic algorithms
int found_algorithm_count = 0;
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
results.data(), &found_algorithm_count));
if (found_algorithm_count == 0) {
throw std::runtime_error("Unable to find any suitable algorithms");
}
results.resize(found_algorithm_count);
heuristic_results_ = std::move(results);
// Create stream and events for timing
cudaStream_t stream;
cudaEvent_t startEvent, stopEvent;
cudaStreamCreate(&stream);
cudaEventCreate(&startEvent);
cudaEventCreate(&stopEvent);
// Test each algorithm multiple times to find the best one
std::vector<float> algoTimes(repeat_iterations);
// Allocate workspace
void *workspace = nullptr;
cudaMalloc(&workspace, max_workspace_size);
// Test each algorithm
algo_metrics_.clear();
algo_metrics_.reserve(found_algorithm_count);
for (int algoIdx = 0; algoIdx < found_algorithm_count; algoIdx++) {
// Skip algorithms that require more workspace than available
if (heuristic_results_[algoIdx].workspaceSize > max_workspace_size) {
continue;
}
// warmup
for (int warmupIdx = 0; warmupIdx < warmup_iterations; warmupIdx++) {
cublasStatus_t status =
cublasLtMatmul(handle_.get(), op_desc_.get(), &alpha, matrix_a, a_desc_.get(), matrix_b, b_desc_.get(),
&beta, matrix_c, c_desc_.get(), matrix_d, d_desc_.get(),
&heuristic_results_[algoIdx].algo, workspace, max_workspace_size, stream);
}
// Test each algorithm multiple times
cudaEventRecord(startEvent, stream);
for (int checkIdx = 0; checkIdx < repeat_iterations; checkIdx++) {
cublasStatus_t status =
cublasLtMatmul(handle_.get(), op_desc_.get(), &alpha, matrix_a, a_desc_.get(), matrix_b, b_desc_.get(),
&beta, matrix_c, c_desc_.get(), matrix_d, d_desc_.get(),
&heuristic_results_[algoIdx].algo, workspace, max_workspace_size, stream);
// Skip if algorithm fails
if (status != CUBLAS_STATUS_SUCCESS) {
algoTimes[checkIdx] = std::numeric_limits<float>::max();
continue;
}
}
cudaEventRecord(stopEvent, stream);
cudaEventSynchronize(stopEvent);
float time = 0;
cudaEventElapsedTime(&time, startEvent, stopEvent);
algoTimes[algoIdx] = time / repeat_iterations;
float meanTime = algoTimes[algoIdx];
float flops = 2.0f * m_ * n_ * k_ / (meanTime * 1e-3f);
// Store metrics
AlgorithmMetrics metrics;
metrics.algo = heuristic_results_[algoIdx].algo;
metrics.workspace_size = heuristic_results_[algoIdx].workspaceSize;
metrics.time = meanTime;
metrics.flops = flops;
algo_metrics_.push_back(metrics);
}
std::sort(algo_metrics_.begin(), algo_metrics_.end(),
[](const AlgorithmMetrics &a, const AlgorithmMetrics &b) { return a.time < b.time; });
if (!algo_metrics_.empty())
heuristic_results_[0].algo = algo_metrics_.front().algo;
// Clean up resources
cudaFree(workspace);
cudaEventDestroy(startEvent);
cudaEventDestroy(stopEvent);
cudaStreamDestroy(stream);
if (!algo_metrics_.empty()) {
return algo_metrics_.front().workspace_size;
}
throw std::runtime_error("No valid algorithms found during autotune");
}
void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta, void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
void *workspace, size_t workspace_size, cudaStream_t stream) { void *workspace, size_t workspace_size, cudaStream_t stream) {
CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */ CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
matrix_a, /* A */ matrix_a, /* A */
a_desc_.get(), matrix_b, /* B */ a_desc_.get(), matrix_b, /* B */
b_desc_.get(), static_cast<const void *>(&beta), /* beta */ b_desc_.get(), static_cast<const void *>(&beta), /* beta */
matrix_c, /* C */ matrix_c, /* C */
c_desc_.get(), matrix_d, /* D */ c_desc_.get(), matrix_d, /* D */
d_desc_.get(), &heuristic_results_.front().algo, /* algo */ d_desc_.get(), &heuristic_results_.front().algo, workspace, /* workspace */
workspace, /* workspace */
workspace_size, stream)); /* stream */ workspace_size, stream)); /* stream */
} }
...@@ -51,9 +51,21 @@ class cublasLtGemm { ...@@ -51,9 +51,21 @@ class cublasLtGemm {
size_t GetAlgorithm(int max_algorithm_count, size_t max_workspace_size); size_t GetAlgorithm(int max_algorithm_count, size_t max_workspace_size);
size_t GetAlgorithmExhaustive(int max_algorithm_count, size_t max_workspace_size, float alpha, float beta,
void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d,
int repeat_iterations = 100, int warmup_iterations = 100);
void Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta, void Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
void *workspace, size_t workspace_size, cudaStream_t stream); void *workspace, size_t workspace_size, cudaStream_t stream);
// Type to store algorithm performance metrics
// Populated by GetAlgorithmExhaustive for each candidate algorithm it times.
struct AlgorithmMetrics {
    cublasLtMatmulAlgo_t algo; // cuBLASLt algorithm descriptor for this candidate
    size_t workspace_size;     // workspace bytes this algorithm requires
    float time;                // measured mean time per matmul call (ms)
    float flops;               // throughput derived from time and problem size
};
private: private:
UniqueHandle handle_; UniqueHandle handle_;
UniqueOpDesc op_desc_; UniqueOpDesc op_desc_;
...@@ -63,4 +75,10 @@ class cublasLtGemm { ...@@ -63,4 +75,10 @@ class cublasLtGemm {
UniqueLayoutDesc d_desc_; UniqueLayoutDesc d_desc_;
UniqueMatmulPreference preference_; UniqueMatmulPreference preference_;
std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results_; std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results_;
std::vector<AlgorithmMetrics> algo_metrics_;
cublasComputeType_t compute_type_ = CUBLAS_COMPUTE_32F;
cudaDataType_t scale_type_ = CUDA_R_32F;
int m_ = 0;
int n_ = 0;
int k_ = 0;
}; };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment