Commit 3ec8c534 authored by yongshk

Initial commit
use bindgen::CargoCallbacks;
use regex::Regex;
use std::{env, path::PathBuf, process::Command};
fn main() {
// Tell cargo to invalidate the built crate whenever the CUDA sources change.
println!("cargo:rerun-if-changed=src/cuda");
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
// Specify the desired architecture version.
let arch = "compute_86"; // For example, using SM 8.6 (Ampere architecture).
let code = "sm_86"; // For the same SM 8.6 (Ampere architecture).
// build the cuda kernels
let cuda_src = PathBuf::from("src/cuda/kernels/my_struct_kernel.cu");
let ptx_file = out_dir.join("my_struct_kernel.ptx");
let nvcc_status = Command::new("nvcc")
.arg("-ptx")
.arg("-o")
.arg(&ptx_file)
.arg(&cuda_src)
.arg(format!("-arch={}", arch))
.arg(format!("-code={}", code))
.status()
.unwrap();
assert!(
nvcc_status.success(),
"Failed to compile CUDA source to PTX."
);
// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// The input header we would like to generate
// bindings for.
.header("src/cuda/includes/wrapper.h")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(CargoCallbacks))
// we use "no_copy" and "no_debug" here because we don't know if we can safely generate them for our structs in C code (they may contain raw pointers)
.no_copy("*")
.no_debug("*")
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");
// we need to make modifications to the generated code
let generated_bindings = bindings.to_string();
// Regex to find raw pointers to float and replace them with CudaSlice<f32>
// You can copy this regex to add/modify other types of pointers, for example "*mut i32"
let pointer_regex = Regex::new(r"\*mut f32").unwrap();
let modified_bindings = pointer_regex.replace_all(&generated_bindings, "CudaSlice<f32>");
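// For illustration only (not enabled in this build): a hypothetical second pass could
// map `*mut i32` pointers to `CudaSlice<i32>` in exactly the same way:
// let i32_pointer_regex = Regex::new(r"\*mut i32").unwrap();
// let modified_bindings = i32_pointer_regex.replace_all(&modified_bindings, "CudaSlice<i32>");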
// Write the modified bindings to $OUT_DIR/bindings.rs.
std::fs::write(out_dir.join("bindings.rs"), modified_bindings.as_bytes())
.expect("Failed to write bindings");
}
// my_struct.h
#ifndef MY_STRUCT_H
#define MY_STRUCT_H
struct MyStruct
{
/// example data containing fixed length array of floats
float data[4];
};
#endif // MY_STRUCT_H
#include "my_struct.h"
#include "../includes/my_struct.h"
extern "C" __global__ void my_struct_kernel(MyStruct *my_structs, const size_t n)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n)
{
my_structs[i].data[0] += 1;
my_structs[i].data[1] += 1;
my_structs[i].data[2] += 1;
my_structs[i].data[3] += 1;
}
}
#![allow(non_snake_case)]
//! This file outlines a typical build process which can be used for more complex CUDA projects utilising this crate.
//! It does the following:
//! 1. Use a `build.rs` file to compile your CUDA code/project into a PTX file. Your CUDA code/project can be as complicated as you need it to be, including multiple files, with headers for your struct definitions, each kernel in its own file, etc.
//! 2. The build process compiles the kernels into a PTX file, which is written to the output directory
//! 3. The build process then uses the `bindgen` crate to generate Rust bindings for the structs defined in your CUDA code
//! 4. In the `main.rs` code, the PTX code is included as a string via the `include_str!` macro, which is then compiled using the functions in this crate (detailed in previous examples)
//!
//! The advantages of having this build process for more complex CUDA projects:
//! - You only need to define your structs once, in your CUDA code, and the Rust bindings are generated automatically
//! - You have full intellisense for your CUDA code since it can be stored under a separate folder or even as part of a separate project
//!
//! There are two files in this example: `main.rs` and `build.rs`. You can reference them and adapt them to your project accordingly. The `src/cuda` folder in this example shows a simple way of defining structs in a separate header, including a `wrapper.h` header for `bindgen`.
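//!
//! The layout assumed by the `build.rs` in this example is:
//! `src/cuda/kernels/my_struct_kernel.cu` for the kernel source,
//! `src/cuda/includes/my_struct.h` for the struct definitions, and
//! `src/cuda/includes/wrapper.h` as the single header handed to `bindgen`.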
use std::time::Instant;
use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync};
use cudarc::nvrtc::Ptx;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
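// SAFETY: `MyStruct` is the bindgen-generated `#[repr(C)]` struct from `my_struct.h`;
// it only contains a fixed-size `[f32; 4]`, so it has the same layout as the CUDA-side
// struct and can safely be copied to and used on the device.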
unsafe impl DeviceRepr for MyStruct {}
impl Default for MyStruct {
fn default() -> Self {
Self { data: [0.0; 4] }
}
}
// include the compiled PTX code as a string
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/my_struct_kernel.ptx"));
fn main() -> Result<(), DriverError> {
// setup GPU device
let now = Instant::now();
let gpu = CudaDevice::new(0)?;
println!("Time taken to initialise CUDA: {:.2?}", now.elapsed());
// compile ptx
let now = Instant::now();
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
gpu.load_ptx(ptx, "my_module", &["my_struct_kernel"])?;
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
// create data
let now = Instant::now();
let n = 10_usize;
let my_structs = vec![MyStruct { data: [1.0; 4] }; n];
// copy to GPU
let gpu_my_structs = gpu.htod_copy(my_structs)?;
println!("Time taken to initialise data: {:.2?}", now.elapsed());
let now = Instant::now();
let f = gpu.get_func("my_module", "my_struct_kernel").unwrap();
unsafe { f.launch(LaunchConfig::for_num_elems(n as u32), (&gpu_my_structs, n)) }?;
println!("Time taken to call kernel: {:.2?}", now.elapsed());
let my_structs = gpu.sync_reclaim(gpu_my_structs)?;
assert!(my_structs.iter().all(|i| i.data == [2.0; 4]));
Ok(())
}
use cudarc::driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::compile_ptx;
const PTX_SRC: &str = "
extern \"C\" __global__ void matmul(float* A, float* B, float* C, int N) {
int ROW = blockIdx.y*blockDim.y+threadIdx.y;
int COL = blockIdx.x*blockDim.x+threadIdx.x;
float tmpSum = 0;
if (ROW < N && COL < N) {
// each thread computes one element of the block sub-matrix
for (int i = 0; i < N; i++) {
tmpSum += A[ROW * N + i] * B[i * N + COL];
}
}
// printf(\"pos, (%d, %d) - N %d - value %d\\n\", ROW, COL, N, tmpSum);
C[ROW * N + COL] = tmpSum;
}
";
fn main() -> Result<(), DriverError> {
let start = std::time::Instant::now();
let ptx = compile_ptx(PTX_SRC).unwrap();
println!("Compilation succeeded in {:?}", start.elapsed());
let dev = CudaDevice::new(0)?;
println!("Built in {:?}", start.elapsed());
dev.load_ptx(ptx, "matmul", &["matmul"])?;
let f = dev.get_func("matmul", "matmul").unwrap();
println!("Loaded in {:?}", start.elapsed());
let a_host = [1.0f32, 2.0, 3.0, 4.0];
let b_host = [1.0f32, 2.0, 3.0, 4.0];
let mut c_host = [0.0f32; 4];
let a_dev = dev.htod_sync_copy(&a_host)?;
let b_dev = dev.htod_sync_copy(&b_host)?;
let mut c_dev = dev.htod_sync_copy(&c_host)?;
println!("Copied in {:?}", start.elapsed());
let cfg = LaunchConfig {
block_dim: (2, 2, 1),
grid_dim: (1, 1, 1),
shared_mem_bytes: 0,
};
unsafe { f.launch(cfg, (&a_dev, &b_dev, &mut c_dev, 2i32)) }?;
dev.dtoh_sync_copy_into(&c_dev, &mut c_host)?;
println!("Found {:?} in {:?}", c_host, start.elapsed());
Ok(())
}
use cudarc::nvrtc::{compile_ptx_with_opts, CompileError, CompileOptions};
fn main() -> Result<(), CompileError> {
let opts = CompileOptions {
ftz: Some(true),
prec_div: Some(false),
prec_sqrt: Some(false),
fmad: Some(true),
..Default::default()
};
let _ = compile_ptx_with_opts(
"
extern \"C\" __global__ void sin_kernel(float *out, const float *inp, int numel) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}",
opts,
)?;
println!("Compilation succeeded!");
Ok(())
}
extern "C" __global__ void sin_kernel(float *out, const float *inp, int numel) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-29745058
// Cuda compilation tools, release 11.3, V11.3.58
// Based on NVVM 7.0.1
//
.version 7.3
.target sm_52
.address_size 64
// .globl sin_kernel
.global .align 4 .b8 __cudart_i2opi_f[24] = {65, 144, 67, 60, 153, 149, 98, 219, 192, 221, 52, 245, 209, 87, 39, 252, 41, 21, 68, 78, 110, 131, 249, 162};
.visible .entry sin_kernel(
.param .u64 sin_kernel_param_0,
.param .u64 sin_kernel_param_1,
.param .u32 sin_kernel_param_2
)
{
.local .align 4 .b8 __local_depot0[28];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .pred %p<12>;
.reg .f32 %f<38>;
.reg .b32 %r<53>;
.reg .f64 %fd<3>;
.reg .b64 %rd<33>;
mov.u64 %SPL, __local_depot0;
ld.param.u64 %rd10, [sin_kernel_param_0];
ld.param.u64 %rd11, [sin_kernel_param_1];
ld.param.u32 %r19, [sin_kernel_param_2];
add.u64 %rd1, %SPL, 0;
mov.u32 %r20, %ntid.x;
mov.u32 %r21, %ctaid.x;
mov.u32 %r22, %tid.x;
mad.lo.s32 %r1, %r21, %r20, %r22;
setp.ge.s32 %p1, %r1, %r19;
@%p1 bra $L__BB0_14;
cvta.to.global.u64 %rd13, %rd11;
cvt.s64.s32 %rd2, %r1;
mul.wide.s32 %rd14, %r1, 4;
add.s64 %rd15, %rd13, %rd14;
ld.global.f32 %f1, [%rd15];
mul.f32 %f14, %f1, 0f3F22F983;
cvt.rni.s32.f32 %r52, %f14;
cvt.rn.f32.s32 %f15, %r52;
mov.f32 %f16, 0fBFC90FDA;
fma.rn.f32 %f17, %f15, %f16, %f1;
mov.f32 %f18, 0fB3A22168;
fma.rn.f32 %f19, %f15, %f18, %f17;
mov.f32 %f20, 0fA7C234C5;
fma.rn.f32 %f35, %f15, %f20, %f19;
abs.f32 %f3, %f1;
setp.leu.f32 %p2, %f3, 0f47CE4780;
@%p2 bra $L__BB0_9;
setp.eq.f32 %p3, %f3, 0f7F800000;
@%p3 bra $L__BB0_8;
bra.uni $L__BB0_3;
$L__BB0_8:
mov.f32 %f23, 0f00000000;
mul.rn.f32 %f35, %f1, %f23;
bra.uni $L__BB0_9;
$L__BB0_3:
mov.b32 %r3, %f1;
bfe.u32 %r24, %r3, 23, 8;
add.s32 %r4, %r24, -128;
shl.b32 %r25, %r3, 8;
or.b32 %r5, %r25, -2147483648;
shr.u32 %r6, %r4, 5;
mov.u64 %rd32, 0;
mov.u32 %r49, 0;
mov.u64 %rd30, __cudart_i2opi_f;
mov.u64 %rd31, %rd1;
$L__BB0_4:
.pragma "nounroll";
ld.global.nc.u32 %r26, [%rd30];
mad.wide.u32 %rd18, %r26, %r5, %rd32;
shr.u64 %rd32, %rd18, 32;
st.local.u32 [%rd31], %rd18;
add.s64 %rd31, %rd31, 4;
add.s64 %rd30, %rd30, 4;
add.s32 %r49, %r49, 1;
setp.ne.s32 %p4, %r49, 6;
@%p4 bra $L__BB0_4;
st.local.u32 [%rd1+24], %rd32;
cvt.u64.u32 %rd19, %r6;
mov.u64 %rd20, 2;
sub.s64 %rd21, %rd20, %rd19;
shl.b64 %rd22, %rd21, 2;
add.s64 %rd23, %rd1, %rd22;
add.s64 %rd9, %rd23, 16;
ld.local.u32 %r50, [%rd23+16];
ld.local.u32 %r51, [%rd23+12];
and.b32 %r11, %r4, 31;
setp.eq.s32 %p5, %r11, 0;
@%p5 bra $L__BB0_7;
mov.u32 %r27, 32;
sub.s32 %r28, %r27, %r11;
shr.u32 %r29, %r51, %r28;
shl.b32 %r30, %r50, %r11;
add.s32 %r50, %r29, %r30;
ld.local.u32 %r31, [%rd9+-8];
shr.u32 %r32, %r31, %r28;
shl.b32 %r33, %r51, %r11;
add.s32 %r51, %r32, %r33;
$L__BB0_7:
and.b32 %r34, %r3, -2147483648;
shr.u32 %r35, %r51, 30;
shl.b32 %r36, %r50, 2;
or.b32 %r37, %r35, %r36;
shr.u32 %r38, %r37, 31;
shr.u32 %r39, %r50, 30;
add.s32 %r40, %r38, %r39;
neg.s32 %r41, %r40;
setp.eq.s32 %p6, %r34, 0;
selp.b32 %r52, %r40, %r41, %p6;
setp.ne.s32 %p7, %r38, 0;
xor.b32 %r42, %r34, -2147483648;
selp.b32 %r43, %r42, %r34, %p7;
selp.b32 %r44, -1, 0, %p7;
xor.b32 %r45, %r37, %r44;
shl.b32 %r46, %r51, 2;
xor.b32 %r47, %r46, %r44;
cvt.u64.u32 %rd24, %r45;
cvt.u64.u32 %rd25, %r47;
bfi.b64 %rd26, %rd24, %rd25, 32, 32;
cvt.rn.f64.s64 %fd1, %rd26;
mul.f64 %fd2, %fd1, 0d3BF921FB54442D19;
cvt.rn.f32.f64 %f21, %fd2;
setp.eq.s32 %p8, %r43, 0;
neg.f32 %f22, %f21;
selp.f32 %f35, %f21, %f22, %p8;
$L__BB0_9:
and.b32 %r18, %r52, 1;
setp.eq.s32 %p9, %r18, 0;
selp.f32 %f7, %f35, 0f3F800000, %p9;
mul.rn.f32 %f8, %f35, %f35;
mov.f32 %f36, 0fB94D4153;
@%p9 bra $L__BB0_11;
mov.f32 %f25, 0fBAB607ED;
mov.f32 %f26, 0f37CBAC00;
fma.rn.f32 %f36, %f26, %f8, %f25;
$L__BB0_11:
selp.f32 %f27, 0f3C0885E4, 0f3D2AAABB, %p9;
fma.rn.f32 %f28, %f36, %f8, %f27;
selp.f32 %f29, 0fBE2AAAA8, 0fBEFFFFFF, %p9;
fma.rn.f32 %f30, %f28, %f8, %f29;
mov.f32 %f31, 0f00000000;
fma.rn.f32 %f32, %f8, %f7, %f31;
fma.rn.f32 %f37, %f30, %f32, %f7;
and.b32 %r48, %r52, 2;
setp.eq.s32 %p11, %r48, 0;
@%p11 bra $L__BB0_13;
mov.f32 %f34, 0fBF800000;
fma.rn.f32 %f37, %f37, %f34, %f31;
$L__BB0_13:
cvta.to.global.u64 %rd27, %rd10;
shl.b64 %rd28, %rd2, 2;
add.s64 %rd29, %rd27, %rd28;
st.global.f32 [%rd29], %f37;
$L__BB0_14:
ret;
}
#!/bin/bash
set -exu
bindgen \
--whitelist-type="^cublas.*" \
--whitelist-function="^cublas.*" \
--default-enum-style=rust \
--no-doc-comments \
--with-derive-default \
--with-derive-eq \
--with-derive-hash \
--with-derive-ord \
--size_t-is-usize \
--use-core \
wrapper.h -- -I/usr/local/cuda/include \
> sys.rs
use super::sys;
use core::ffi::{c_int, c_longlong};
use half::f16;
extern "C" {
pub fn cublasHgemm(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f16,
A: *const f16,
lda: c_int,
B: *const f16,
ldb: c_int,
beta: *const f16,
C: *mut f16,
ldc: c_int,
) -> sys::cublasStatus_t;
}
extern "C" {
pub fn cublasHgemmStridedBatched(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f16,
A: *const f16,
lda: c_int,
strideA: c_longlong,
B: *const f16,
ldb: c_int,
strideB: c_longlong,
beta: *const f16,
C: *mut f16,
ldc: c_int,
strideC: c_longlong,
batchCount: c_int,
) -> sys::cublasStatus_t;
}
//! Wrappers around the [cublas API](https://docs.nvidia.com/cuda/cublas/index.html),
//! in three levels. See crate documentation for description of each.
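//!
//! 1. [`sys`]: raw FFI bindings generated by `bindgen`.
//! 2. [`result`]: thin wrappers that convert the raw status codes into `Result`s with [`result::CublasError`].
//! 3. [`safe`]: a safe interface built around [`safe::CudaBlas`] and the [`safe::Gemm`] / [`safe::Gemv`] traits.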
#[cfg(feature = "f16")]
pub mod half;
pub mod result;
pub mod safe;
#[allow(warnings)]
pub mod sys;
pub use safe::*;
use super::sys;
use core::ffi::{c_int, c_longlong, c_void};
use core::mem::MaybeUninit;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct CublasError(pub sys::cublasStatus_t);
impl sys::cublasStatus_t {
pub fn result(self) -> Result<(), CublasError> {
match self {
sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
_ => Err(CublasError(self)),
}
}
}
#[cfg(feature = "std")]
impl std::fmt::Display for CublasError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
#[cfg(feature = "std")]
impl std::error::Error for CublasError {}
/// Creates a handle to the cuBLAS library. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublascreate)
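///
/// # Example
///
/// A minimal sketch of the handle lifecycle (not compiled as a doc-test):
///
/// ```ignore
/// let handle = create_handle()?;
/// // ... use `handle` with the gemm/gemv wrappers below ...
/// unsafe { destroy_handle(handle) }?;
/// ```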
pub fn create_handle() -> Result<sys::cublasHandle_t, CublasError> {
let mut handle = MaybeUninit::uninit();
unsafe {
sys::cublasCreate_v2(handle.as_mut_ptr()).result()?;
Ok(handle.assume_init())
}
}
/// Destroys a handle previously created with [create_handle()]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasdestroy)
///
/// # Safety
///
/// `handle` must not have been freed already.
pub unsafe fn destroy_handle(handle: sys::cublasHandle_t) -> Result<(), CublasError> {
sys::cublasDestroy_v2(handle).result()
}
/// Sets the stream cuBLAS will use. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublassetstream)
///
/// # Safety
///
/// `handle` and `stream` must be valid.
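///
/// A sketch (not compiled as a doc-test), assuming `stream` is a valid
/// `sys::cudaStream_t` obtained from the driver API or cast from a `cudarc` stream:
///
/// ```ignore
/// unsafe { set_stream(handle, stream) }?;
/// ```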
pub unsafe fn set_stream(
handle: sys::cublasHandle_t,
stream: sys::cudaStream_t,
) -> Result<(), CublasError> {
sys::cublasSetStream_v2(handle, stream).result()
}
/// Single precision matrix vector multiplication. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemv)
///
/// # Safety
///
/// - `a`, `x`, and `y` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn sgemv(
handle: sys::cublasHandle_t,
trans: sys::cublasOperation_t,
m: c_int,
n: c_int,
alpha: *const f32,
a: *const f32,
lda: c_int,
x: *const f32,
incx: c_int,
beta: *const f32,
y: *mut f32,
incy: c_int,
) -> Result<(), CublasError> {
sys::cublasSgemv_v2(handle, trans, m, n, alpha, a, lda, x, incx, beta, y, incy).result()
}
/// Double precision matrix vector multiplication. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemv)
///
/// # Safety
///
/// - `a`, `x`, and `y` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn dgemv(
handle: sys::cublasHandle_t,
trans: sys::cublasOperation_t,
m: c_int,
n: c_int,
alpha: *const f64,
a: *const f64,
lda: c_int,
x: *const f64,
incx: c_int,
beta: *const f64,
y: *mut f64,
incy: c_int,
) -> Result<(), CublasError> {
sys::cublasDgemv_v2(handle, trans, m, n, alpha, a, lda, x, incx, beta, y, incy).result()
}
#[cfg(feature = "f16")]
/// Half precision matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn hgemm(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const half::f16,
a: *const half::f16,
lda: c_int,
b: *const half::f16,
ldb: c_int,
beta: *const half::f16,
c: *mut half::f16,
ldc: c_int,
) -> Result<(), CublasError> {
super::half::cublasHgemm(
handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
)
.result()
}
/// Single precision matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn sgemm(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f32,
a: *const f32,
lda: c_int,
b: *const f32,
ldb: c_int,
beta: *const f32,
c: *mut f32,
ldc: c_int,
) -> Result<(), CublasError> {
sys::cublasSgemm_v2(
handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
)
.result()
}
/// Double precision matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn dgemm(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f64,
a: *const f64,
lda: c_int,
b: *const f64,
ldb: c_int,
beta: *const f64,
c: *mut f64,
ldc: c_int,
) -> Result<(), CublasError> {
sys::cublasDgemm_v2(
handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
)
.result()
}
#[cfg(feature = "f16")]
/// Half precision batched matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmstridedbatched)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn hgemm_strided_batched(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const half::f16,
a: *const half::f16,
lda: c_int,
stride_a: c_longlong,
b: *const half::f16,
ldb: c_int,
stride_b: c_longlong,
beta: *const half::f16,
c: *mut half::f16,
ldc: c_int,
stride_c: c_longlong,
batch_size: c_int,
) -> Result<(), CublasError> {
super::half::cublasHgemmStridedBatched(
handle, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc,
stride_c, batch_size,
)
.result()
}
/// Single precision batched matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmstridedbatched)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn sgemm_strided_batched(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f32,
a: *const f32,
lda: c_int,
stride_a: c_longlong,
b: *const f32,
ldb: c_int,
stride_b: c_longlong,
beta: *const f32,
c: *mut f32,
ldc: c_int,
stride_c: c_longlong,
batch_size: c_int,
) -> Result<(), CublasError> {
sys::cublasSgemmStridedBatched(
handle, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc,
stride_c, batch_size,
)
.result()
}
/// Double precision batched matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmstridedbatched)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn dgemm_strided_batched(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f64,
a: *const f64,
lda: c_int,
stride_a: c_longlong,
b: *const f64,
ldb: c_int,
stride_b: c_longlong,
beta: *const f64,
c: *mut f64,
ldc: c_int,
stride_c: c_longlong,
batch_size: c_int,
) -> Result<(), CublasError> {
sys::cublasDgemmStridedBatched(
handle, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc,
stride_c, batch_size,
)
.result()
}
/// Matmul with data types specified as parameters. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemm_ex(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const c_void,
a: *const c_void,
a_type: sys::cudaDataType,
lda: c_int,
b: *const c_void,
b_type: sys::cudaDataType,
ldb: c_int,
beta: *const c_void,
c: *mut c_void,
c_type: sys::cudaDataType,
ldc: c_int,
compute_type: sys::cublasComputeType_t,
algo: sys::cublasGemmAlgo_t,
) -> Result<(), CublasError> {
sys::cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
a,
a_type,
lda,
b,
b_type,
ldb,
beta,
c,
c_type,
ldc,
compute_type,
algo,
)
.result()
}
/// Strided batched matmul with data types specified as parameters. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be correct for the passed buffers
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemm_strided_batched_ex(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const c_void,
a: *const c_void,
a_type: sys::cudaDataType,
lda: c_int,
stride_a: c_longlong,
b: *const c_void,
b_type: sys::cudaDataType,
ldb: c_int,
stride_b: c_longlong,
beta: *const c_void,
c: *mut c_void,
c_type: sys::cudaDataType,
ldc: c_int,
stride_c: c_longlong,
batch_count: c_int,
compute_type: sys::cublasComputeType_t,
algo: sys::cublasGemmAlgo_t,
) -> Result<(), CublasError> {
sys::cublasGemmStridedBatchedEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
a,
a_type,
lda,
stride_a,
b,
b_type,
ldb,
stride_b,
beta,
c,
c_type,
ldc,
stride_c,
batch_count,
compute_type,
algo,
)
.result()
}
//! Safe abstractions around [crate::cublas::result] for doing gemm and gemv.
#![allow(clippy::too_many_arguments)]
use super::{result, result::CublasError, sys};
use crate::driver::{CudaDevice, CudaStream, DevicePtr, DevicePtrMut};
use core::ffi::{c_int, c_longlong};
use std::sync::Arc;
/// Wrapper around [sys::cublasHandle_t]
///
/// 1. Create with [CudaBlas::new()]
/// 2. Execute gemm/gemv kernels with [Gemv] and [Gemm]. `f32` and `f64` are supported
/// for both; `f16`/`bf16` gemm is additionally available behind the `f16` feature.
///
/// Note: This holds an instance of [`Arc<CudaDevice>`], so it will prevent the device
/// from being dropped.
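///
/// # Example
///
/// A minimal sketch, mirroring the tests at the bottom of this file (not compiled as a doc-test):
///
/// ```ignore
/// let dev = CudaDevice::new(0)?;
/// let blas = CudaBlas::new(dev.clone())?;
/// ```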
#[derive(Debug)]
pub struct CudaBlas {
pub(crate) handle: sys::cublasHandle_t,
pub(crate) device: Arc<CudaDevice>,
}
unsafe impl Send for CudaBlas {}
unsafe impl Sync for CudaBlas {}
impl CudaBlas {
/// Creates a new cublas handle and sets the stream to the `device`'s stream.
pub fn new(device: Arc<CudaDevice>) -> Result<Self, CublasError> {
device.bind_to_thread().unwrap();
let handle = result::create_handle()?;
let blas = Self { handle, device };
unsafe { result::set_stream(handle, blas.device.stream as *mut _) }?;
Ok(blas)
}
/// Returns a reference to the underlying cublas handle.
pub fn handle(&self) -> &sys::cublasHandle_t {
&self.handle
}
/// Sets the handle's current stream to either the specified stream or the device's
/// default work stream.
///
/// # Safety
/// This is unsafe because you can end up scheduling multiple concurrent kernels that all
/// write to the same memory address.
pub unsafe fn set_stream(&self, opt_stream: Option<&CudaStream>) -> Result<(), CublasError> {
match opt_stream {
Some(s) => result::set_stream(self.handle, s.stream as *mut _),
None => result::set_stream(self.handle, self.device.stream as *mut _),
}
}
}
impl Drop for CudaBlas {
fn drop(&mut self) {
let handle = std::mem::replace(&mut self.handle, std::ptr::null_mut());
if !handle.is_null() {
unsafe { result::destroy_handle(handle) }.unwrap();
}
}
}
/// Configuration for [Gemv]
#[derive(Debug, Copy, Clone)]
pub struct GemvConfig<T> {
pub trans: sys::cublasOperation_t,
pub m: c_int,
pub n: c_int,
pub alpha: T,
pub lda: c_int,
pub incx: c_int,
pub beta: T,
pub incy: c_int,
}
/// Matrix vector multiplication with elements of type `T`
pub trait Gemv<T> {
/// Matrix vector multiplication.
///
/// # Safety
/// This is unsafe because improper arguments may lead to invalid
/// memory accesses.
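///
/// A sketch of a single precision call; `blas`, `a_dev`, `x_dev`, `y_dev`, `m`, `n`
/// and `lda` are placeholders here (see the tests at the bottom of this file):
///
/// ```ignore
/// let cfg = GemvConfig {
///     trans: sys::cublasOperation_t::CUBLAS_OP_T,
///     m, n, alpha: 1.0f32, lda, incx: 1, beta: 0.0, incy: 1,
/// };
/// unsafe { blas.gemv(cfg, &a_dev, &x_dev, &mut y_dev) }?;
/// ```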
unsafe fn gemv<A: DevicePtr<T>, X: DevicePtr<T>, Y: DevicePtrMut<T>>(
&self,
cfg: GemvConfig<T>,
a: &A,
x: &X,
y: &mut Y,
) -> Result<(), CublasError>;
}
impl Gemv<f32> for CudaBlas {
unsafe fn gemv<A: DevicePtr<f32>, X: DevicePtr<f32>, Y: DevicePtrMut<f32>>(
&self,
cfg: GemvConfig<f32>,
a: &A,
x: &X,
y: &mut Y,
) -> Result<(), CublasError> {
result::sgemv(
self.handle,
cfg.trans,
cfg.m,
cfg.n,
(&cfg.alpha) as *const _,
*a.device_ptr() as *const _,
cfg.lda,
*x.device_ptr() as *const _,
cfg.incx,
(&cfg.beta) as *const _,
*y.device_ptr_mut() as *mut _,
cfg.incy,
)
}
}
impl Gemv<f64> for CudaBlas {
unsafe fn gemv<A: DevicePtr<f64>, X: DevicePtr<f64>, Y: DevicePtrMut<f64>>(
&self,
cfg: GemvConfig<f64>,
a: &A,
x: &X,
y: &mut Y,
) -> Result<(), CublasError> {
result::dgemv(
self.handle,
cfg.trans,
cfg.m,
cfg.n,
(&cfg.alpha) as *const _,
*a.device_ptr() as *const _,
cfg.lda,
*x.device_ptr() as *const _,
cfg.incx,
(&cfg.beta) as *const _,
*y.device_ptr_mut() as *mut _,
cfg.incy,
)
}
}
/// Configuration for [Gemm]
#[derive(Debug, Copy, Clone)]
pub struct GemmConfig<T> {
pub transa: sys::cublasOperation_t,
pub transb: sys::cublasOperation_t,
pub m: c_int,
pub n: c_int,
pub k: c_int,
pub alpha: T,
pub lda: c_int,
pub ldb: c_int,
pub beta: T,
pub ldc: c_int,
}
/// Configuration for [Gemm] strided batched call
#[derive(Debug, Copy, Clone)]
pub struct StridedBatchedConfig<T> {
pub gemm: GemmConfig<T>,
pub batch_size: c_int,
pub stride_a: c_longlong,
pub stride_b: c_longlong,
pub stride_c: c_longlong,
}
/// Matrix matrix multiplication with elements of type `T`.
pub trait Gemm<T> {
/// Matrix matrix multiplication. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm)
///
/// # Safety
/// This is unsafe because improper arguments may lead to invalid
/// memory accesses.
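///
/// Note that cuBLAS expects column-major storage; the tests at the bottom of this
/// file pass row-major data by swapping the `a`/`b` arguments and the `m`/`n` sizes.
/// A sketch with placeholder names:
///
/// ```ignore
/// let cfg = GemmConfig {
///     transa: sys::cublasOperation_t::CUBLAS_OP_N,
///     transb: sys::cublasOperation_t::CUBLAS_OP_N,
///     m, n, k, alpha: 1.0f32, lda, ldb, beta: 0.0, ldc,
/// };
/// unsafe { blas.gemm(cfg, &a_dev, &b_dev, &mut c_dev) }?;
/// ```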
unsafe fn gemm<A: DevicePtr<T>, B: DevicePtr<T>, C: DevicePtrMut<T>>(
&self,
cfg: GemmConfig<T>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError>;
/// Batched matrix multiplication with stride support on batch dimension. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmstridedbatched)
///
/// # Safety
/// This is unsafe because improper arguments may lead to invalid
/// memory accesses.
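///
/// A sketch with placeholder names; the configuration is a per-batch [GemmConfig]
/// plus the stride (in elements) between consecutive matrices in each buffer:
///
/// ```ignore
/// let cfg = StridedBatchedConfig {
///     gemm, // a GemmConfig<f32> built as for `Gemm::gemm`
///     batch_size,
///     stride_a, stride_b, stride_c,
/// };
/// unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut c_dev) }?;
/// ```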
unsafe fn gemm_strided_batched<A: DevicePtr<T>, B: DevicePtr<T>, C: DevicePtrMut<T>>(
&self,
cfg: StridedBatchedConfig<T>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError>;
}
#[cfg(feature = "f16")]
impl Gemm<half::f16> for CudaBlas {
unsafe fn gemm<A: DevicePtr<half::f16>, B: DevicePtr<half::f16>, C: DevicePtrMut<half::f16>>(
&self,
cfg: GemmConfig<half::f16>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
let alpha: f32 = cfg.alpha.to_f32();
let beta: f32 = cfg.beta.to_f32();
result::gemm_ex(
self.handle,
cfg.transa,
cfg.transb,
cfg.m,
cfg.n,
cfg.k,
(&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.lda,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.ldb,
(&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.ldc,
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)
}
unsafe fn gemm_strided_batched<
A: DevicePtr<half::f16>,
B: DevicePtr<half::f16>,
C: DevicePtrMut<half::f16>,
>(
&self,
cfg: StridedBatchedConfig<half::f16>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
let alpha: f32 = cfg.gemm.alpha.to_f32();
let beta: f32 = cfg.gemm.beta.to_f32();
result::gemm_strided_batched_ex(
self.handle,
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
(&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldb,
cfg.stride_b,
(&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)
}
}
#[cfg(feature = "f16")]
impl Gemm<half::bf16> for CudaBlas {
unsafe fn gemm<
A: DevicePtr<half::bf16>,
B: DevicePtr<half::bf16>,
C: DevicePtrMut<half::bf16>,
>(
&self,
cfg: GemmConfig<half::bf16>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
let alpha: f32 = cfg.alpha.to_f32();
let beta: f32 = cfg.beta.to_f32();
result::gemm_ex(
self.handle,
cfg.transa,
cfg.transb,
cfg.m,
cfg.n,
cfg.k,
(&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.lda,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.ldb,
(&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.ldc,
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)
}
unsafe fn gemm_strided_batched<
A: DevicePtr<half::bf16>,
B: DevicePtr<half::bf16>,
C: DevicePtrMut<half::bf16>,
>(
&self,
cfg: StridedBatchedConfig<half::bf16>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
let alpha: f32 = cfg.gemm.alpha.to_f32();
let beta: f32 = cfg.gemm.beta.to_f32();
result::gemm_strided_batched_ex(
self.handle,
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
(&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldb,
cfg.stride_b,
(&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
)
}
}
impl Gemm<f32> for CudaBlas {
unsafe fn gemm<A: DevicePtr<f32>, B: DevicePtr<f32>, C: DevicePtrMut<f32>>(
&self,
cfg: GemmConfig<f32>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
result::sgemm(
self.handle,
cfg.transa,
cfg.transb,
cfg.m,
cfg.n,
cfg.k,
(&cfg.alpha) as *const _,
*a.device_ptr() as *const _,
cfg.lda,
*b.device_ptr() as *const _,
cfg.ldb,
(&cfg.beta) as *const _,
*c.device_ptr_mut() as *mut _,
cfg.ldc,
)
}
unsafe fn gemm_strided_batched<A: DevicePtr<f32>, B: DevicePtr<f32>, C: DevicePtrMut<f32>>(
&self,
cfg: StridedBatchedConfig<f32>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
result::sgemm_strided_batched(
self.handle,
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
(&cfg.gemm.alpha) as *const _,
*a.device_ptr() as *const _,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
cfg.gemm.ldb,
cfg.stride_b,
(&cfg.gemm.beta) as *const _,
*c.device_ptr_mut() as *mut _,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
)
}
}
impl Gemm<f64> for CudaBlas {
unsafe fn gemm<A: DevicePtr<f64>, B: DevicePtr<f64>, C: DevicePtrMut<f64>>(
&self,
cfg: GemmConfig<f64>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
result::dgemm(
self.handle,
cfg.transa,
cfg.transb,
cfg.m,
cfg.n,
cfg.k,
(&cfg.alpha) as *const _,
*a.device_ptr() as *const _,
cfg.lda,
*b.device_ptr() as *const _,
cfg.ldb,
(&cfg.beta) as *const _,
*c.device_ptr_mut() as *mut _,
cfg.ldc,
)
}
unsafe fn gemm_strided_batched<A: DevicePtr<f64>, B: DevicePtr<f64>, C: DevicePtrMut<f64>>(
&self,
cfg: StridedBatchedConfig<f64>,
a: &A,
b: &B,
c: &mut C,
) -> Result<(), CublasError> {
result::dgemm_strided_batched(
self.handle,
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
(&cfg.gemm.alpha) as *const _,
*a.device_ptr() as *const _,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
cfg.gemm.ldb,
cfg.stride_b,
(&cfg.gemm.beta) as *const _,
*c.device_ptr_mut() as *mut _,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
)
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::needless_range_loop)]
use super::*;
fn gemv_truth<T, const M: usize, const N: usize>(
alpha: T,
a: &[[T; N]; M],
x: &[T; N],
beta: T,
y: &mut [T; M],
) where
T: Copy + Clone + std::ops::AddAssign + std::ops::MulAssign + std::ops::Mul<T, Output = T>,
{
for m in 0..M {
y[m] *= beta;
}
for m in 0..M {
for n in 0..N {
y[m] += alpha * a[m][n] * x[n];
}
}
}
fn gemm_truth<T, const M: usize, const N: usize, const K: usize>(
alpha: T,
a: &[[T; K]; M],
b: &[[T; N]; K],
beta: T,
c: &mut [[T; N]; M],
) where
T: Copy + Clone + std::ops::AddAssign + std::ops::MulAssign + std::ops::Mul<T, Output = T>,
{
for m in 0..M {
for n in 0..N {
c[m][n] *= beta;
}
}
for m in 0..M {
for n in 0..N {
for k in 0..K {
c[m][n] += alpha * a[m][k] * b[k][n];
}
}
}
}
#[test]
fn test_sgemv() {
let dev = CudaDevice::new(0).unwrap();
let blas = CudaBlas::new(dev.clone()).unwrap();
const M: usize = 2;
const N: usize = 5;
let a: [[f32; N]; M] = [
[0.9314776, 0.10300648, -0.620774, 1.5270752, 0.0259804],
[0.16820757, -0.94463515, -1.3850101, 1.0600523, 1.5124008],
];
#[rustfmt::skip]
let b: [f32; N] = [-1.3441996, 1.3965541, -0.89106345, 0.21196432, -0.95535654];
let mut c: [f32; M] = [1.0; M];
gemv_truth(1.0, &a, &b, 0.0, &mut c);
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy(&[
0.9314776, 0.10300648, -0.620774, 1.5270752, 0.0259804,
0.16820757, -0.94463515, -1.3850101, 1.0600523, 1.5124008,
]).unwrap();
let b_dev = dev.htod_sync_copy(&b).unwrap();
let mut c_dev = dev.alloc_zeros(M).unwrap();
unsafe {
blas.gemv(
GemvConfig {
trans: sys::cublasOperation_t::CUBLAS_OP_T,
m: N as i32,
n: M as i32,
alpha: 1.0,
lda: N as i32,
incx: 1,
beta: 0.0,
incy: 1,
},
&a_dev,
&b_dev,
&mut c_dev,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for i in 0..M {
assert!((c_host[i] - c[i]).abs() <= 1e-6);
}
}
#[test]
fn test_dgemv() {
let dev = CudaDevice::new(0).unwrap();
let blas = CudaBlas::new(dev.clone()).unwrap();
const M: usize = 8;
const N: usize = 3;
let a: [[f64; N]; M] = [
[0.96151888, -0.36771390, 0.94069099],
[2.20621538, -0.16479775, -1.78425562],
[0.41080803, -0.56567699, -0.72781092],
[-0.65718418, -0.14466463, 0.63984287],
[0.20309605, 0.40480086, -1.57559848],
[0.85628128, -0.51614553, -1.15904427],
[-1.84258616, 0.24096519, -0.04563522],
[-0.53364468, -1.07902217, 0.46823528],
];
#[rustfmt::skip]
let b: [f64; N] = [ 0.39745075, -1.06677043, -1.18272650];
let mut c: [f64; M] = [1.0; M];
gemv_truth(1.0, &a, &b, 0.0, &mut c);
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy(&[
0.96151888, -0.36771390, 0.94069099,
2.20621538, -0.16479775, -1.78425562,
0.41080803, -0.56567699, -0.72781092,
-0.65718418, -0.14466463, 0.63984287,
0.20309605, 0.40480086, -1.57559848,
0.85628128, -0.51614553, -1.15904427,
-1.84258616, 0.24096519, -0.04563522,
-0.53364468, -1.07902217, 0.46823528,
]).unwrap();
let b_dev = dev.htod_sync_copy(&b).unwrap();
let mut c_dev = dev.alloc_zeros(M).unwrap();
unsafe {
blas.gemv(
GemvConfig {
trans: sys::cublasOperation_t::CUBLAS_OP_T,
m: N as i32,
n: M as i32,
alpha: 1.0,
lda: N as i32,
incx: 1,
beta: 0.0,
incy: 1,
},
&a_dev,
&b_dev,
&mut c_dev,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for i in 0..M {
assert!((c_host[i] - c[i]).abs() <= 1e-8);
}
}
#[cfg(feature = "f16")]
#[test]
fn test_hgemm() {
let dev = CudaDevice::new(0).unwrap();
let blas = CudaBlas::new(dev.clone()).unwrap();
const M: usize = 3;
const K: usize = 4;
const N: usize = 5;
let a: [[half::f16; K]; M] = [
[-0.5944882, 1.8055636, 0.52204555, -0.00397902],
[-0.38346434, -0.38013917, 0.4198623, -0.22479166],
[-1.6661372, -0.4568837, -0.9043474, 0.39125723],
]
.map(|r| r.map(half::f16::from_f32));
let b: [[half::f16; N]; K] = [
[1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938],
[1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096],
[1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629],
[3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792],
]
.map(|r| r.map(half::f16::from_f32));
let mut c: [[half::f16; N]; M] = [[0.0; N]; M].map(|r| r.map(half::f16::from_f32));
gemm_truth(
half::f16::from_f32(1.0),
&a,
&b,
half::f16::from_f32(0.0),
&mut c,
);
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy::<half::f16>(&[
-0.5944882, 1.8055636, 0.52204555, -0.00397902,
-0.38346434, -0.38013917, 0.4198623, -0.22479166,
-1.6661372, -0.4568837, -0.9043474, 0.39125723,
].map(half::f16::from_f32)).unwrap();
#[rustfmt::skip]
let b_dev = dev.htod_sync_copy::<half::f16>(&[
1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938,
1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096,
1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629,
3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792,
].map(half::f16::from_f32)).unwrap();
let mut c_dev = dev.alloc_zeros::<half::f16>(M * N).unwrap();
unsafe {
blas.gemm(
GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: N as i32,
n: M as i32,
k: K as i32,
alpha: half::f16::from_f32(1.0),
lda: N as i32,
ldb: K as i32,
beta: half::f16::from_f32(0.0),
ldc: N as i32,
},
&b_dev,
&a_dev,
&mut c_dev,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for m in 0..M {
for n in 0..N {
let found = c_host[m * N + n];
let expected = c[m][n];
assert!(
(found.to_f32() - expected.to_f32()).abs() <= 1e-2,
"found={found:?}, expected={expected:?}"
);
}
}
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy::<half::bf16>(&[
-0.5944882, 1.8055636, 0.52204555, -0.00397902,
-0.38346434, -0.38013917, 0.4198623, -0.22479166,
-1.6661372, -0.4568837, -0.9043474, 0.39125723,
].map(half::bf16::from_f32)).unwrap();
#[rustfmt::skip]
let b_dev = dev.htod_sync_copy::<half::bf16>(&[
1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938,
1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096,
1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629,
3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792,
].map(half::bf16::from_f32)).unwrap();
let mut c_dev = dev.alloc_zeros::<half::bf16>(M * N).unwrap();
unsafe {
blas.gemm(
GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: N as i32,
n: M as i32,
k: K as i32,
alpha: half::bf16::from_f32(1.0),
lda: N as i32,
ldb: K as i32,
beta: half::bf16::from_f32(0.0),
ldc: N as i32,
},
&b_dev,
&a_dev,
&mut c_dev,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for m in 0..M {
for n in 0..N {
let found = c_host[m * N + n];
let expected = c[m][n];
assert!(
(half::bf16::to_f32(found) - half::f16::to_f32(expected)).abs() <= 1e-2,
"found={found:?}, expected={expected:?}"
);
}
}
}
#[test]
fn test_sgemm() {
let dev = CudaDevice::new(0).unwrap();
let blas = CudaBlas::new(dev.clone()).unwrap();
const M: usize = 3;
const K: usize = 4;
const N: usize = 5;
let a: [[f32; K]; M] = [
[-0.5944882, 1.8055636, 0.52204555, -0.00397902],
[-0.38346434, -0.38013917, 0.4198623, -0.22479166],
[-1.6661372, -0.4568837, -0.9043474, 0.39125723],
];
let b: [[f32; N]; K] = [
[1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938],
[1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096],
[1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629],
[3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792],
];
let mut c: [[f32; N]; M] = [[0.0; N]; M];
gemm_truth(1.0, &a, &b, 0.0, &mut c);
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy::<f32>(&[
-0.5944882, 1.8055636, 0.52204555, -0.00397902,
-0.38346434, -0.38013917, 0.4198623, -0.22479166,
-1.6661372, -0.4568837, -0.9043474, 0.39125723,
]).unwrap();
#[rustfmt::skip]
let b_dev = dev.htod_sync_copy::<f32>(&[
1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938,
1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096,
1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629,
3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792,
]).unwrap();
let mut c_dev = dev.alloc_zeros::<f32>(M * N).unwrap();
unsafe {
blas.gemm(
GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: N as i32,
n: M as i32,
k: K as i32,
alpha: 1.0,
lda: N as i32,
ldb: K as i32,
beta: 0.0,
ldc: N as i32,
},
&b_dev,
&a_dev,
&mut c_dev,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for m in 0..M {
for n in 0..N {
assert!((c_host[m * N + n] - c[m][n]).abs() <= 1e-6);
}
}
}
#[test]
fn test_dgemm() {
let dev = CudaDevice::new(0).unwrap();
let blas = CudaBlas::new(dev.clone()).unwrap();
const M: usize = 4;
const K: usize = 3;
const N: usize = 2;
let a: [[f64; K]; M] = [
[-0.70925030, -1.01357541, -0.64827034],
[2.18493467, -0.61584842, -1.43844327],
[-1.34792593, 0.68840750, -0.48057214],
[1.22180992, 1.16245157, 0.01253436],
];
let b: [[f64; N]; K] = [
[-0.72735474, 1.35931170],
[1.71798307, -0.13296247],
[0.26855612, -1.95189980],
];
let mut c: [[f64; N]; M] = [[0.0; N]; M];
gemm_truth(1.0, &a, &b, 0.0, &mut c);
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy::<f64>(&[
-0.70925030, -1.01357541, -0.64827034,
2.18493467, -0.61584842, -1.43844327,
-1.34792593, 0.68840750, -0.48057214,
1.22180992, 1.16245157, 0.01253436,
]).unwrap();
#[rustfmt::skip]
let b_dev = dev.htod_sync_copy::<f64>(&[
-0.72735474, 1.35931170,
1.71798307, -0.13296247,
0.26855612, -1.95189980,
]).unwrap();
let mut c_dev = dev.alloc_zeros::<f64>(M * N).unwrap();
unsafe {
blas.gemm(
GemmConfig {
transa: sys::cublasOperation_t::CUBLAS_OP_N,
transb: sys::cublasOperation_t::CUBLAS_OP_N,
m: N as i32,
n: M as i32,
k: K as i32,
alpha: 1.0,
lda: N as i32,
ldb: K as i32,
beta: 0.0,
ldc: N as i32,
},
&b_dev,
&a_dev,
&mut c_dev,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for m in 0..M {
for n in 0..N {
assert!((c_host[m * N + n] - c[m][n]).abs() <= 1e-10);
}
}
}
}
/* automatically generated by rust-bindgen 0.60.1 */
#[repr(C)]
#[repr(align(8))]
#[derive(Debug, Default, Copy, Clone, PartialOrd, PartialEq)]
pub struct float2 {
pub x: f32,
pub y: f32,
}
#[test]
fn bindgen_test_layout_float2() {
assert_eq!(
::core::mem::size_of::<float2>(),
8usize,
concat!("Size of: ", stringify!(float2))
);
assert_eq!(
::core::mem::align_of::<float2>(),
8usize,
concat!("Alignment of ", stringify!(float2))
);
fn test_field_x() {
assert_eq!(
unsafe {
let uninit = ::core::mem::MaybeUninit::<float2>::uninit();
let ptr = uninit.as_ptr();
::core::ptr::addr_of!((*ptr).x) as usize - ptr as usize
},
0usize,
concat!("Offset of field: ", stringify!(float2), "::", stringify!(x))
);
}
test_field_x();
fn test_field_y() {
assert_eq!(
unsafe {
let uninit = ::core::mem::MaybeUninit::<float2>::uninit();
let ptr = uninit.as_ptr();
::core::ptr::addr_of!((*ptr).y) as usize - ptr as usize
},
4usize,
concat!("Offset of field: ", stringify!(float2), "::", stringify!(y))
);
}
test_field_y();
}
#[repr(C)]
#[repr(align(16))]
#[derive(Debug, Default, Copy, Clone, PartialOrd, PartialEq)]
pub struct double2 {
pub x: f64,
pub y: f64,
}
#[test]
fn bindgen_test_layout_double2() {
assert_eq!(
::core::mem::size_of::<double2>(),
16usize,
concat!("Size of: ", stringify!(double2))
);
assert_eq!(
::core::mem::align_of::<double2>(),
16usize,
concat!("Alignment of ", stringify!(double2))
);
fn test_field_x() {
assert_eq!(
unsafe {
let uninit = ::core::mem::MaybeUninit::<double2>::uninit();
let ptr = uninit.as_ptr();
::core::ptr::addr_of!((*ptr).x) as usize - ptr as usize
},
0usize,
concat!(
"Offset of field: ",
stringify!(double2),
"::",
stringify!(x)
)
);
}
test_field_x();
fn test_field_y() {
assert_eq!(
unsafe {
let uninit = ::core::mem::MaybeUninit::<double2>::uninit();
let ptr = uninit.as_ptr();
::core::ptr::addr_of!((*ptr).y) as usize - ptr as usize
},
8usize,
concat!(
"Offset of field: ",
stringify!(double2),
"::",
stringify!(y)
)
);
}
test_field_y();
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct CUstream_st {
_unused: [u8; 0],
}
pub type cudaStream_t = *mut CUstream_st;
pub type cuFloatComplex = float2;
pub type cuDoubleComplex = double2;
pub type cuComplex = cuFloatComplex;
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cudaDataType_t {
CUDA_R_16F = 2,
CUDA_C_16F = 6,
CUDA_R_16BF = 14,
CUDA_C_16BF = 15,
CUDA_R_32F = 0,
CUDA_C_32F = 4,
CUDA_R_64F = 1,
CUDA_C_64F = 5,
CUDA_R_4I = 16,
CUDA_C_4I = 17,
CUDA_R_4U = 18,
CUDA_C_4U = 19,
CUDA_R_8I = 3,
CUDA_C_8I = 7,
CUDA_R_8U = 8,
CUDA_C_8U = 9,
CUDA_R_16I = 20,
CUDA_C_16I = 21,
CUDA_R_16U = 22,
CUDA_C_16U = 23,
CUDA_R_32I = 10,
CUDA_C_32I = 11,
CUDA_R_32U = 12,
CUDA_C_32U = 13,
CUDA_R_64I = 24,
CUDA_C_64I = 25,
CUDA_R_64U = 26,
CUDA_C_64U = 27,
}
pub use self::cudaDataType_t as cudaDataType;
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum libraryPropertyType_t {
MAJOR_VERSION = 0,
MINOR_VERSION = 1,
PATCH_LEVEL = 2,
}
pub use self::libraryPropertyType_t as libraryPropertyType;
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasStatus_t {
CUBLAS_STATUS_SUCCESS = 0,
CUBLAS_STATUS_NOT_INITIALIZED = 1,
CUBLAS_STATUS_ALLOC_FAILED = 3,
CUBLAS_STATUS_INVALID_VALUE = 7,
CUBLAS_STATUS_ARCH_MISMATCH = 8,
CUBLAS_STATUS_MAPPING_ERROR = 11,
CUBLAS_STATUS_EXECUTION_FAILED = 13,
CUBLAS_STATUS_INTERNAL_ERROR = 14,
CUBLAS_STATUS_NOT_SUPPORTED = 15,
CUBLAS_STATUS_LICENSE_ERROR = 16,
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasFillMode_t {
CUBLAS_FILL_MODE_LOWER = 0,
CUBLAS_FILL_MODE_UPPER = 1,
CUBLAS_FILL_MODE_FULL = 2,
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasDiagType_t {
CUBLAS_DIAG_NON_UNIT = 0,
CUBLAS_DIAG_UNIT = 1,
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasSideMode_t {
CUBLAS_SIDE_LEFT = 0,
CUBLAS_SIDE_RIGHT = 1,
}
impl cublasOperation_t {
pub const CUBLAS_OP_HERMITAN: cublasOperation_t = cublasOperation_t::CUBLAS_OP_C;
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasOperation_t {
CUBLAS_OP_N = 0,
CUBLAS_OP_T = 1,
CUBLAS_OP_C = 2,
CUBLAS_OP_CONJG = 3,
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasPointerMode_t {
CUBLAS_POINTER_MODE_HOST = 0,
CUBLAS_POINTER_MODE_DEVICE = 1,
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasAtomicsMode_t {
CUBLAS_ATOMICS_NOT_ALLOWED = 0,
CUBLAS_ATOMICS_ALLOWED = 1,
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_DEFAULT: cublasGemmAlgo_t = cublasGemmAlgo_t::CUBLAS_GEMM_DFALT;
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_DFALT_TENSOR_OP: cublasGemmAlgo_t =
cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
#[repr(i32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasGemmAlgo_t {
CUBLAS_GEMM_DFALT = -1,
CUBLAS_GEMM_ALGO0 = 0,
CUBLAS_GEMM_ALGO1 = 1,
CUBLAS_GEMM_ALGO2 = 2,
CUBLAS_GEMM_ALGO3 = 3,
CUBLAS_GEMM_ALGO4 = 4,
CUBLAS_GEMM_ALGO5 = 5,
CUBLAS_GEMM_ALGO6 = 6,
CUBLAS_GEMM_ALGO7 = 7,
CUBLAS_GEMM_ALGO8 = 8,
CUBLAS_GEMM_ALGO9 = 9,
CUBLAS_GEMM_ALGO10 = 10,
CUBLAS_GEMM_ALGO11 = 11,
CUBLAS_GEMM_ALGO12 = 12,
CUBLAS_GEMM_ALGO13 = 13,
CUBLAS_GEMM_ALGO14 = 14,
CUBLAS_GEMM_ALGO15 = 15,
CUBLAS_GEMM_ALGO16 = 16,
CUBLAS_GEMM_ALGO17 = 17,
CUBLAS_GEMM_ALGO18 = 18,
CUBLAS_GEMM_ALGO19 = 19,
CUBLAS_GEMM_ALGO20 = 20,
CUBLAS_GEMM_ALGO21 = 21,
CUBLAS_GEMM_ALGO22 = 22,
CUBLAS_GEMM_ALGO23 = 23,
CUBLAS_GEMM_DEFAULT_TENSOR_OP = 99,
CUBLAS_GEMM_ALGO0_TENSOR_OP = 100,
CUBLAS_GEMM_ALGO1_TENSOR_OP = 101,
CUBLAS_GEMM_ALGO2_TENSOR_OP = 102,
CUBLAS_GEMM_ALGO3_TENSOR_OP = 103,
CUBLAS_GEMM_ALGO4_TENSOR_OP = 104,
CUBLAS_GEMM_ALGO5_TENSOR_OP = 105,
CUBLAS_GEMM_ALGO6_TENSOR_OP = 106,
CUBLAS_GEMM_ALGO7_TENSOR_OP = 107,
CUBLAS_GEMM_ALGO8_TENSOR_OP = 108,
CUBLAS_GEMM_ALGO9_TENSOR_OP = 109,
CUBLAS_GEMM_ALGO10_TENSOR_OP = 110,
CUBLAS_GEMM_ALGO11_TENSOR_OP = 111,
CUBLAS_GEMM_ALGO12_TENSOR_OP = 112,
CUBLAS_GEMM_ALGO13_TENSOR_OP = 113,
CUBLAS_GEMM_ALGO14_TENSOR_OP = 114,
CUBLAS_GEMM_ALGO15_TENSOR_OP = 115,
}
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasMath_t {
CUBLAS_DEFAULT_MATH = 0,
CUBLAS_TENSOR_OP_MATH = 1,
CUBLAS_PEDANTIC_MATH = 2,
CUBLAS_TF32_TENSOR_OP_MATH = 3,
CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION = 16,
}
pub use self::cudaDataType as cublasDataType_t;
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum cublasComputeType_t {
CUBLAS_COMPUTE_16F = 64,
CUBLAS_COMPUTE_16F_PEDANTIC = 65,
CUBLAS_COMPUTE_32F = 68,
CUBLAS_COMPUTE_32F_PEDANTIC = 69,
CUBLAS_COMPUTE_32F_FAST_16F = 74,
CUBLAS_COMPUTE_32F_FAST_16BF = 75,
CUBLAS_COMPUTE_32F_FAST_TF32 = 77,
CUBLAS_COMPUTE_64F = 70,
CUBLAS_COMPUTE_64F_PEDANTIC = 71,
CUBLAS_COMPUTE_32I = 72,
CUBLAS_COMPUTE_32I_PEDANTIC = 73,
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cublasContext {
_unused: [u8; 0],
}
pub type cublasHandle_t = *mut cublasContext;
extern "C" {
pub fn cublasCreate_v2(handle: *mut cublasHandle_t) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDestroy_v2(handle: cublasHandle_t) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetVersion_v2(
handle: cublasHandle_t,
version: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetProperty(
type_: libraryPropertyType,
value: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetCudartVersion() -> usize;
}
extern "C" {
pub fn cublasSetWorkspace_v2(
handle: cublasHandle_t,
workspace: *mut ::core::ffi::c_void,
workspaceSizeInBytes: usize,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetStream_v2(handle: cublasHandle_t, streamId: cudaStream_t) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetStream_v2(
handle: cublasHandle_t,
streamId: *mut cudaStream_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetPointerMode_v2(
handle: cublasHandle_t,
mode: *mut cublasPointerMode_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetPointerMode_v2(
handle: cublasHandle_t,
mode: cublasPointerMode_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetAtomicsMode(
handle: cublasHandle_t,
mode: *mut cublasAtomicsMode_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetAtomicsMode(
handle: cublasHandle_t,
mode: cublasAtomicsMode_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetMathMode(handle: cublasHandle_t, mode: *mut cublasMath_t) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetMathMode(handle: cublasHandle_t, mode: cublasMath_t) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetSmCountTarget(
handle: cublasHandle_t,
smCountTarget: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetSmCountTarget(
handle: cublasHandle_t,
smCountTarget: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetStatusName(status: cublasStatus_t) -> *const core::ffi::c_char;
}
extern "C" {
pub fn cublasGetStatusString(status: cublasStatus_t) -> *const core::ffi::c_char;
}
pub type cublasLogCallback =
::core::option::Option<unsafe extern "C" fn(msg: *const core::ffi::c_char)>;
extern "C" {
pub fn cublasLoggerConfigure(
logIsOn: core::ffi::c_int,
logToStdOut: core::ffi::c_int,
logToStdErr: core::ffi::c_int,
logFileName: *const core::ffi::c_char,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetLoggerCallback(userCallback: cublasLogCallback) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetLoggerCallback(userCallback: *mut cublasLogCallback) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetVector(
n: core::ffi::c_int,
elemSize: core::ffi::c_int,
x: *const ::core::ffi::c_void,
incx: core::ffi::c_int,
devicePtr: *mut ::core::ffi::c_void,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetVector(
n: core::ffi::c_int,
elemSize: core::ffi::c_int,
x: *const ::core::ffi::c_void,
incx: core::ffi::c_int,
y: *mut ::core::ffi::c_void,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetMatrix(
rows: core::ffi::c_int,
cols: core::ffi::c_int,
elemSize: core::ffi::c_int,
A: *const ::core::ffi::c_void,
lda: core::ffi::c_int,
B: *mut ::core::ffi::c_void,
ldb: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetMatrix(
rows: core::ffi::c_int,
cols: core::ffi::c_int,
elemSize: core::ffi::c_int,
A: *const ::core::ffi::c_void,
lda: core::ffi::c_int,
B: *mut ::core::ffi::c_void,
ldb: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetVectorAsync(
n: core::ffi::c_int,
elemSize: core::ffi::c_int,
hostPtr: *const ::core::ffi::c_void,
incx: core::ffi::c_int,
devicePtr: *mut ::core::ffi::c_void,
incy: core::ffi::c_int,
stream: cudaStream_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetVectorAsync(
n: core::ffi::c_int,
elemSize: core::ffi::c_int,
devicePtr: *const ::core::ffi::c_void,
incx: core::ffi::c_int,
hostPtr: *mut ::core::ffi::c_void,
incy: core::ffi::c_int,
stream: cudaStream_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSetMatrixAsync(
rows: core::ffi::c_int,
cols: core::ffi::c_int,
elemSize: core::ffi::c_int,
A: *const ::core::ffi::c_void,
lda: core::ffi::c_int,
B: *mut ::core::ffi::c_void,
ldb: core::ffi::c_int,
stream: cudaStream_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGetMatrixAsync(
rows: core::ffi::c_int,
cols: core::ffi::c_int,
elemSize: core::ffi::c_int,
A: *const ::core::ffi::c_void,
lda: core::ffi::c_int,
B: *mut ::core::ffi::c_void,
ldb: core::ffi::c_int,
stream: cudaStream_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasXerbla(srName: *const core::ffi::c_char, info: core::ffi::c_int);
}
extern "C" {
pub fn cublasNrm2Ex(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
result: *mut ::core::ffi::c_void,
resultType: cudaDataType,
executionType: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSnrm2_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
result: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDnrm2_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
result: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasScnrm2_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
result: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDznrm2_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
result: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDotEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
y: *const ::core::ffi::c_void,
yType: cudaDataType,
incy: core::ffi::c_int,
result: *mut ::core::ffi::c_void,
resultType: cudaDataType,
executionType: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDotcEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
y: *const ::core::ffi::c_void,
yType: cudaDataType,
incy: core::ffi::c_int,
result: *mut ::core::ffi::c_void,
resultType: cudaDataType,
executionType: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSdot_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
y: *const f32,
incy: core::ffi::c_int,
result: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDdot_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
y: *const f64,
incy: core::ffi::c_int,
result: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCdotu_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *const cuComplex,
incy: core::ffi::c_int,
result: *mut cuComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCdotc_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *const cuComplex,
incy: core::ffi::c_int,
result: *mut cuComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZdotu_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *const cuDoubleComplex,
incy: core::ffi::c_int,
result: *mut cuDoubleComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZdotc_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *const cuDoubleComplex,
incy: core::ffi::c_int,
result: *mut cuDoubleComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasScalEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const ::core::ffi::c_void,
alphaType: cudaDataType,
x: *mut ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
executionType: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSscal_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *mut f32,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDscal_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *mut f64,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCscal_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
x: *mut cuComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsscal_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *mut cuComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZscal_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZdscal_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasAxpyEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const ::core::ffi::c_void,
alphaType: cudaDataType,
x: *const ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
y: *mut ::core::ffi::c_void,
yType: cudaDataType,
incy: core::ffi::c_int,
executiontype: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSaxpy_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *const f32,
incx: core::ffi::c_int,
y: *mut f32,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDaxpy_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *const f64,
incx: core::ffi::c_int,
y: *mut f64,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCaxpy_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZaxpy_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCopyEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
y: *mut ::core::ffi::c_void,
yType: cudaDataType,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasScopy_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
y: *mut f32,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDcopy_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
y: *mut f64,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCcopy_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZcopy_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSswap_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut f32,
incx: core::ffi::c_int,
y: *mut f32,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDswap_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut f64,
incx: core::ffi::c_int,
y: *mut f64,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCswap_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut cuComplex,
incx: core::ffi::c_int,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZswap_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSwapEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
y: *mut ::core::ffi::c_void,
yType: cudaDataType,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIsamax_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIdamax_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIcamax_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIzamax_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIamaxEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIsamin_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIdamin_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIcamin_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIzamin_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasIaminEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
result: *mut core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasAsumEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
result: *mut ::core::ffi::c_void,
resultType: cudaDataType,
executiontype: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSasum_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
result: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDasum_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
result: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasScasum_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
result: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDzasum_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
result: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSrot_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut f32,
incx: core::ffi::c_int,
y: *mut f32,
incy: core::ffi::c_int,
c: *const f32,
s: *const f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDrot_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut f64,
incx: core::ffi::c_int,
y: *mut f64,
incy: core::ffi::c_int,
c: *const f64,
s: *const f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCrot_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut cuComplex,
incx: core::ffi::c_int,
y: *mut cuComplex,
incy: core::ffi::c_int,
c: *const f32,
s: *const cuComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsrot_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut cuComplex,
incx: core::ffi::c_int,
y: *mut cuComplex,
incy: core::ffi::c_int,
c: *const f32,
s: *const f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZrot_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
c: *const f64,
s: *const cuDoubleComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZdrot_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
c: *const f64,
s: *const f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasRotEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
y: *mut ::core::ffi::c_void,
yType: cudaDataType,
incy: core::ffi::c_int,
c: *const ::core::ffi::c_void,
s: *const ::core::ffi::c_void,
csType: cudaDataType,
executiontype: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSrotg_v2(
handle: cublasHandle_t,
a: *mut f32,
b: *mut f32,
c: *mut f32,
s: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDrotg_v2(
handle: cublasHandle_t,
a: *mut f64,
b: *mut f64,
c: *mut f64,
s: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCrotg_v2(
handle: cublasHandle_t,
a: *mut cuComplex,
b: *mut cuComplex,
c: *mut f32,
s: *mut cuComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZrotg_v2(
handle: cublasHandle_t,
a: *mut cuDoubleComplex,
b: *mut cuDoubleComplex,
c: *mut f64,
s: *mut cuDoubleComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasRotgEx(
handle: cublasHandle_t,
a: *mut ::core::ffi::c_void,
b: *mut ::core::ffi::c_void,
abType: cudaDataType,
c: *mut ::core::ffi::c_void,
s: *mut ::core::ffi::c_void,
csType: cudaDataType,
executiontype: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSrotm_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut f32,
incx: core::ffi::c_int,
y: *mut f32,
incy: core::ffi::c_int,
param: *const f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDrotm_v2(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut f64,
incx: core::ffi::c_int,
y: *mut f64,
incy: core::ffi::c_int,
param: *const f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasRotmEx(
handle: cublasHandle_t,
n: core::ffi::c_int,
x: *mut ::core::ffi::c_void,
xType: cudaDataType,
incx: core::ffi::c_int,
y: *mut ::core::ffi::c_void,
yType: cudaDataType,
incy: core::ffi::c_int,
param: *const ::core::ffi::c_void,
paramType: cudaDataType,
executiontype: cudaDataType,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSrotmg_v2(
handle: cublasHandle_t,
d1: *mut f32,
d2: *mut f32,
x1: *mut f32,
y1: *const f32,
param: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDrotmg_v2(
handle: cublasHandle_t,
d1: *mut f64,
d2: *mut f64,
x1: *mut f64,
y1: *const f64,
param: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasRotmgEx(
handle: cublasHandle_t,
d1: *mut ::core::ffi::c_void,
d1Type: cudaDataType,
d2: *mut ::core::ffi::c_void,
d2Type: cudaDataType,
x1: *mut ::core::ffi::c_void,
x1Type: cudaDataType,
y1: *const ::core::ffi::c_void,
y1Type: cudaDataType,
param: *mut ::core::ffi::c_void,
paramType: cudaDataType,
executiontype: cudaDataType,
) -> cublasStatus_t;
}
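// Level-2 BLAS routines (matrix-vector): gemv and gbmv, the triangular trmv/tbmv/tpmv
// multiplies and trsv/tbsv/tpsv solves, the symmetric/Hermitian symv/hemv, sbmv/hbmv and
// spmv/hpmv products, and the rank-1/rank-2 updates ger, syr/her, spr/hpr, syr2/her2 and
// spr2/hpr2.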
extern "C" {
pub fn cublasSgemv_v2(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
beta: *const f32,
y: *mut f32,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgemv_v2(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
beta: *const f64,
y: *mut f64,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemv_v2(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
beta: *const cuComplex,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgemv_v2(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
beta: *const cuDoubleComplex,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgbmv_v2(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
kl: core::ffi::c_int,
ku: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
beta: *const f32,
y: *mut f32,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgbmv_v2(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
kl: core::ffi::c_int,
ku: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
beta: *const f64,
y: *mut f64,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgbmv_v2(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
kl: core::ffi::c_int,
ku: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
beta: *const cuComplex,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgbmv_v2(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
kl: core::ffi::c_int,
ku: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
beta: *const cuDoubleComplex,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStrmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
A: *const f32,
lda: core::ffi::c_int,
x: *mut f32,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtrmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
A: *const f64,
lda: core::ffi::c_int,
x: *mut f64,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtrmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *mut cuComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtrmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStbmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const f32,
lda: core::ffi::c_int,
x: *mut f32,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtbmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const f64,
lda: core::ffi::c_int,
x: *mut f64,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtbmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *mut cuComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtbmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStpmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
AP: *const f32,
x: *mut f32,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtpmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
AP: *const f64,
x: *mut f64,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtpmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
AP: *const cuComplex,
x: *mut cuComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtpmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
AP: *const cuDoubleComplex,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStrsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
A: *const f32,
lda: core::ffi::c_int,
x: *mut f32,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtrsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
A: *const f64,
lda: core::ffi::c_int,
x: *mut f64,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtrsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *mut cuComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtrsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStpsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
AP: *const f32,
x: *mut f32,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtpsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
AP: *const f64,
x: *mut f64,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtpsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
AP: *const cuComplex,
x: *mut cuComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtpsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
AP: *const cuDoubleComplex,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStbsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const f32,
lda: core::ffi::c_int,
x: *mut f32,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtbsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const f64,
lda: core::ffi::c_int,
x: *mut f64,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtbsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *mut cuComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtbsv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *mut cuDoubleComplex,
incx: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSsymv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
beta: *const f32,
y: *mut f32,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDsymv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
beta: *const f64,
y: *mut f64,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsymv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
beta: *const cuComplex,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZsymv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
beta: *const cuDoubleComplex,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasChemv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
beta: *const cuComplex,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZhemv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
beta: *const cuDoubleComplex,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSsbmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
beta: *const f32,
y: *mut f32,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDsbmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
beta: *const f64,
y: *mut f64,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasChbmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
beta: *const cuComplex,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZhbmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
beta: *const cuDoubleComplex,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSspmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f32,
AP: *const f32,
x: *const f32,
incx: core::ffi::c_int,
beta: *const f32,
y: *mut f32,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDspmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f64,
AP: *const f64,
x: *const f64,
incx: core::ffi::c_int,
beta: *const f64,
y: *mut f64,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasChpmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
AP: *const cuComplex,
x: *const cuComplex,
incx: core::ffi::c_int,
beta: *const cuComplex,
y: *mut cuComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZhpmv_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
AP: *const cuDoubleComplex,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
beta: *const cuDoubleComplex,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSger_v2(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
x: *const f32,
incx: core::ffi::c_int,
y: *const f32,
incy: core::ffi::c_int,
A: *mut f32,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDger_v2(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
x: *const f64,
incx: core::ffi::c_int,
y: *const f64,
incy: core::ffi::c_int,
A: *mut f64,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgeru_v2(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *const cuComplex,
incy: core::ffi::c_int,
A: *mut cuComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgerc_v2(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *const cuComplex,
incy: core::ffi::c_int,
A: *mut cuComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgeru_v2(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *const cuDoubleComplex,
incy: core::ffi::c_int,
A: *mut cuDoubleComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgerc_v2(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *const cuDoubleComplex,
incy: core::ffi::c_int,
A: *mut cuDoubleComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSsyr_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *const f32,
incx: core::ffi::c_int,
A: *mut f32,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDsyr_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *const f64,
incx: core::ffi::c_int,
A: *mut f64,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsyr_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
x: *const cuComplex,
incx: core::ffi::c_int,
A: *mut cuComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZsyr_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
A: *mut cuDoubleComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCher_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *const cuComplex,
incx: core::ffi::c_int,
A: *mut cuComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZher_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
A: *mut cuDoubleComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSspr_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *const f32,
incx: core::ffi::c_int,
AP: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDspr_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *const f64,
incx: core::ffi::c_int,
AP: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasChpr_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *const cuComplex,
incx: core::ffi::c_int,
AP: *mut cuComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZhpr_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
AP: *mut cuDoubleComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSsyr2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *const f32,
incx: core::ffi::c_int,
y: *const f32,
incy: core::ffi::c_int,
A: *mut f32,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDsyr2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *const f64,
incx: core::ffi::c_int,
y: *const f64,
incy: core::ffi::c_int,
A: *mut f64,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsyr2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *const cuComplex,
incy: core::ffi::c_int,
A: *mut cuComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZsyr2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *const cuDoubleComplex,
incy: core::ffi::c_int,
A: *mut cuDoubleComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCher2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *const cuComplex,
incy: core::ffi::c_int,
A: *mut cuComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZher2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *const cuDoubleComplex,
incy: core::ffi::c_int,
A: *mut cuDoubleComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSspr2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f32,
x: *const f32,
incx: core::ffi::c_int,
y: *const f32,
incy: core::ffi::c_int,
AP: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDspr2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const f64,
x: *const f64,
incx: core::ffi::c_int,
y: *const f64,
incy: core::ffi::c_int,
AP: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasChpr2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuComplex,
x: *const cuComplex,
incx: core::ffi::c_int,
y: *const cuComplex,
incy: core::ffi::c_int,
AP: *mut cuComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZhpr2_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
y: *const cuDoubleComplex,
incy: core::ffi::c_int,
AP: *mut cuDoubleComplex,
) -> cublasStatus_t;
}
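// Batched and strided-batched GEMV variants.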
extern "C" {
pub fn cublasSgemvBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
Aarray: *const *const f32,
lda: core::ffi::c_int,
xarray: *const *const f32,
incx: core::ffi::c_int,
beta: *const f32,
yarray: *const *mut f32,
incy: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgemvBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
Aarray: *const *const f64,
lda: core::ffi::c_int,
xarray: *const *const f64,
incx: core::ffi::c_int,
beta: *const f64,
yarray: *const *mut f64,
incy: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemvBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
Aarray: *const *const cuComplex,
lda: core::ffi::c_int,
xarray: *const *const cuComplex,
incx: core::ffi::c_int,
beta: *const cuComplex,
yarray: *const *mut cuComplex,
incy: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgemvBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
Aarray: *const *const cuDoubleComplex,
lda: core::ffi::c_int,
xarray: *const *const cuDoubleComplex,
incx: core::ffi::c_int,
beta: *const cuDoubleComplex,
yarray: *const *mut cuDoubleComplex,
incy: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgemvStridedBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
x: *const f32,
incx: core::ffi::c_int,
stridex: core::ffi::c_longlong,
beta: *const f32,
y: *mut f32,
incy: core::ffi::c_int,
stridey: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgemvStridedBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
x: *const f64,
incx: core::ffi::c_int,
stridex: core::ffi::c_longlong,
beta: *const f64,
y: *mut f64,
incy: core::ffi::c_int,
stridey: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemvStridedBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
x: *const cuComplex,
incx: core::ffi::c_int,
stridex: core::ffi::c_longlong,
beta: *const cuComplex,
y: *mut cuComplex,
incy: core::ffi::c_int,
stridey: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgemvStridedBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
stridex: core::ffi::c_longlong,
beta: *const cuDoubleComplex,
y: *mut cuDoubleComplex,
incy: core::ffi::c_int,
stridey: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
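// Level-3 BLAS routines (matrix-matrix), starting with the GEMM family and its
// mixed-precision *Ex variants.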
extern "C" {
pub fn cublasSgemm_v2(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
B: *const f32,
ldb: core::ffi::c_int,
beta: *const f32,
C: *mut f32,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgemm_v2(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
B: *const f64,
ldb: core::ffi::c_int,
beta: *const f64,
C: *mut f64,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemm_v2(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemm3m(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemm3mEx(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
B: *const ::core::ffi::c_void,
Btype: cudaDataType,
ldb: core::ffi::c_int,
beta: *const cuComplex,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgemm_v2(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const cuDoubleComplex,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgemm3m(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const cuDoubleComplex,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgemmEx(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
B: *const ::core::ffi::c_void,
Btype: cudaDataType,
ldb: core::ffi::c_int,
beta: *const f32,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGemmEx(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const ::core::ffi::c_void,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
B: *const ::core::ffi::c_void,
Btype: cudaDataType,
ldb: core::ffi::c_int,
beta: *const ::core::ffi::c_void,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
computeType: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemmEx(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
B: *const ::core::ffi::c_void,
Btype: cudaDataType,
ldb: core::ffi::c_int,
beta: *const cuComplex,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasUint8gemmBias(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
transc: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
A: *const core::ffi::c_uchar,
A_bias: core::ffi::c_int,
lda: core::ffi::c_int,
B: *const core::ffi::c_uchar,
B_bias: core::ffi::c_int,
ldb: core::ffi::c_int,
C: *mut core::ffi::c_uchar,
C_bias: core::ffi::c_int,
ldc: core::ffi::c_int,
C_mult: core::ffi::c_int,
C_shift: core::ffi::c_int,
) -> cublasStatus_t;
}
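// --- Illustrative usage sketch (not part of the generated bindings) ---------------------
// A minimal sketch, under stated assumptions, of a single-precision GEMM call,
// C = alpha * A * B + beta * C, through the cublasSgemm_v2 binding above. A, B and C are
// assumed to be column-major device pointers of shape m x k, k x n and m x n that were
// allocated and filled elsewhere; `handle` is assumed to come from cublasCreate_v2, and the
// operation selectors passed by the caller are assumed to be the no-transpose operation
// (CUBLAS_OP_N), so the leading dimensions are simply m, k and m. The function name
// `example_sgemm` is purely illustrative.
#[allow(dead_code)]
unsafe fn example_sgemm(
    handle: cublasHandle_t,
    transa: cublasOperation_t,
    transb: cublasOperation_t,
    m: core::ffi::c_int,
    n: core::ffi::c_int,
    k: core::ffi::c_int,
    a: *const f32,
    b: *const f32,
    c: *mut f32,
) -> cublasStatus_t {
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;
    cublasSgemm_v2(
        handle, transa, transb, m, n, k,
        &alpha, // host pointer to alpha (default pointer mode is host)
        a, m,   // A and lda (leading dimension of column-major A, no transpose assumed)
        b, k,   // B and ldb
        &beta,  // host pointer to beta
        c, m,   // C and ldc
    )
}
// The rank-k update family follows: syrk/herk, syr2k/her2k and the syrkx/herkx variants.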
extern "C" {
pub fn cublasSsyrk_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
beta: *const f32,
C: *mut f32,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDsyrk_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
beta: *const f64,
C: *mut f64,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsyrk_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZsyrk_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
beta: *const cuDoubleComplex,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsyrkEx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
beta: *const cuComplex,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsyrk3mEx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
beta: *const cuComplex,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCherk_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const cuComplex,
lda: core::ffi::c_int,
beta: *const f32,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZherk_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f64,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
beta: *const f64,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCherkEx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
beta: *const f32,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCherk3mEx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
beta: *const f32,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSsyr2k_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
B: *const f32,
ldb: core::ffi::c_int,
beta: *const f32,
C: *mut f32,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDsyr2k_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
B: *const f64,
ldb: core::ffi::c_int,
beta: *const f64,
C: *mut f64,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsyr2k_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZsyr2k_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const cuDoubleComplex,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCher2k_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
beta: *const f32,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZher2k_v2(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const f64,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSsyrkx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
B: *const f32,
ldb: core::ffi::c_int,
beta: *const f32,
C: *mut f32,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDsyrkx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
B: *const f64,
ldb: core::ffi::c_int,
beta: *const f64,
C: *mut f64,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsyrkx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZsyrkx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const cuDoubleComplex,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCherkx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
beta: *const f32,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZherkx(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const f64,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSsymm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
B: *const f32,
ldb: core::ffi::c_int,
beta: *const f32,
C: *mut f32,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDsymm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
B: *const f64,
ldb: core::ffi::c_int,
beta: *const f64,
C: *mut f64,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCsymm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZsymm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const cuDoubleComplex,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasChemm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZhemm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const cuDoubleComplex,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStrsm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
B: *mut f32,
ldb: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtrsm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
B: *mut f64,
ldb: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtrsm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *mut cuComplex,
ldb: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtrsm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *mut cuDoubleComplex,
ldb: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStrmm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
B: *const f32,
ldb: core::ffi::c_int,
C: *mut f32,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtrmm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
B: *const f64,
ldb: core::ffi::c_int,
C: *mut f64,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtrmm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
B: *const cuComplex,
ldb: core::ffi::c_int,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtrmm_v2(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
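// Batched and strided-batched GEMM variants, including cublasGemmBatchedEx and
// cublasGemmStridedBatchedEx, which take explicit compute-type and algorithm selectors.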
extern "C" {
pub fn cublasSgemmBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
Aarray: *const *const f32,
lda: core::ffi::c_int,
Barray: *const *const f32,
ldb: core::ffi::c_int,
beta: *const f32,
Carray: *const *mut f32,
ldc: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgemmBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f64,
Aarray: *const *const f64,
lda: core::ffi::c_int,
Barray: *const *const f64,
ldb: core::ffi::c_int,
beta: *const f64,
Carray: *const *mut f64,
ldc: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemmBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
Aarray: *const *const cuComplex,
lda: core::ffi::c_int,
Barray: *const *const cuComplex,
ldb: core::ffi::c_int,
beta: *const cuComplex,
Carray: *const *mut cuComplex,
ldc: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemm3mBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
Aarray: *const *const cuComplex,
lda: core::ffi::c_int,
Barray: *const *const cuComplex,
ldb: core::ffi::c_int,
beta: *const cuComplex,
Carray: *const *mut cuComplex,
ldc: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgemmBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
Aarray: *const *const cuDoubleComplex,
lda: core::ffi::c_int,
Barray: *const *const cuDoubleComplex,
ldb: core::ffi::c_int,
beta: *const cuDoubleComplex,
Carray: *const *mut cuDoubleComplex,
ldc: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGemmBatchedEx(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const ::core::ffi::c_void,
Aarray: *const *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
Barray: *const *const ::core::ffi::c_void,
Btype: cudaDataType,
ldb: core::ffi::c_int,
beta: *const ::core::ffi::c_void,
Carray: *const *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
batchCount: core::ffi::c_int,
computeType: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasGemmStridedBatchedEx(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const ::core::ffi::c_void,
A: *const ::core::ffi::c_void,
Atype: cudaDataType,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
B: *const ::core::ffi::c_void,
Btype: cudaDataType,
ldb: core::ffi::c_int,
strideB: core::ffi::c_longlong,
beta: *const ::core::ffi::c_void,
C: *mut ::core::ffi::c_void,
Ctype: cudaDataType,
ldc: core::ffi::c_int,
strideC: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
computeType: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgemmStridedBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
B: *const f32,
ldb: core::ffi::c_int,
strideB: core::ffi::c_longlong,
beta: *const f32,
C: *mut f32,
ldc: core::ffi::c_int,
strideC: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgemmStridedBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
B: *const f64,
ldb: core::ffi::c_int,
strideB: core::ffi::c_longlong,
beta: *const f64,
C: *mut f64,
ldc: core::ffi::c_int,
strideC: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemmStridedBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
B: *const cuComplex,
ldb: core::ffi::c_int,
strideB: core::ffi::c_longlong,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
strideC: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgemm3mStridedBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
B: *const cuComplex,
ldb: core::ffi::c_int,
strideB: core::ffi::c_longlong,
beta: *const cuComplex,
C: *mut cuComplex,
ldc: core::ffi::c_int,
strideC: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgemmStridedBatched(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
k: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
strideA: core::ffi::c_longlong,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
strideB: core::ffi::c_longlong,
beta: *const cuDoubleComplex,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
strideC: core::ffi::c_longlong,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgeam(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
A: *const f32,
lda: core::ffi::c_int,
beta: *const f32,
B: *const f32,
ldb: core::ffi::c_int,
C: *mut f32,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgeam(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
A: *const f64,
lda: core::ffi::c_int,
beta: *const f64,
B: *const f64,
ldb: core::ffi::c_int,
C: *mut f64,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgeam(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const cuComplex,
lda: core::ffi::c_int,
beta: *const cuComplex,
B: *const cuComplex,
ldb: core::ffi::c_int,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgeam(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
beta: *const cuDoubleComplex,
B: *const cuDoubleComplex,
ldb: core::ffi::c_int,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgetrfBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *mut f32,
lda: core::ffi::c_int,
P: *mut core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgetrfBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *mut f64,
lda: core::ffi::c_int,
P: *mut core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgetrfBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *mut cuComplex,
lda: core::ffi::c_int,
P: *mut core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgetrfBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *mut cuDoubleComplex,
lda: core::ffi::c_int,
P: *mut core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgetriBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *const f32,
lda: core::ffi::c_int,
P: *const core::ffi::c_int,
C: *const *mut f32,
ldc: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgetriBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *const f64,
lda: core::ffi::c_int,
P: *const core::ffi::c_int,
C: *const *mut f64,
ldc: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgetriBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *const cuComplex,
lda: core::ffi::c_int,
P: *const core::ffi::c_int,
C: *const *mut cuComplex,
ldc: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgetriBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *const cuDoubleComplex,
lda: core::ffi::c_int,
P: *const core::ffi::c_int,
C: *const *mut cuDoubleComplex,
ldc: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgetrsBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
nrhs: core::ffi::c_int,
Aarray: *const *const f32,
lda: core::ffi::c_int,
devIpiv: *const core::ffi::c_int,
Barray: *const *mut f32,
ldb: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgetrsBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
nrhs: core::ffi::c_int,
Aarray: *const *const f64,
lda: core::ffi::c_int,
devIpiv: *const core::ffi::c_int,
Barray: *const *mut f64,
ldb: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgetrsBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
nrhs: core::ffi::c_int,
Aarray: *const *const cuComplex,
lda: core::ffi::c_int,
devIpiv: *const core::ffi::c_int,
Barray: *const *mut cuComplex,
ldb: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgetrsBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
n: core::ffi::c_int,
nrhs: core::ffi::c_int,
Aarray: *const *const cuDoubleComplex,
lda: core::ffi::c_int,
devIpiv: *const core::ffi::c_int,
Barray: *const *mut cuDoubleComplex,
ldb: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStrsmBatched(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f32,
A: *const *const f32,
lda: core::ffi::c_int,
B: *const *mut f32,
ldb: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtrsmBatched(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const f64,
A: *const *const f64,
lda: core::ffi::c_int,
B: *const *mut f64,
ldb: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtrsmBatched(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuComplex,
A: *const *const cuComplex,
lda: core::ffi::c_int,
B: *const *mut cuComplex,
ldb: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtrsmBatched(
handle: cublasHandle_t,
side: cublasSideMode_t,
uplo: cublasFillMode_t,
trans: cublasOperation_t,
diag: cublasDiagType_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
alpha: *const cuDoubleComplex,
A: *const *const cuDoubleComplex,
lda: core::ffi::c_int,
B: *const *mut cuDoubleComplex,
ldb: core::ffi::c_int,
batchCount: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSmatinvBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *const f32,
lda: core::ffi::c_int,
Ainv: *const *mut f32,
lda_inv: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDmatinvBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *const f64,
lda: core::ffi::c_int,
Ainv: *const *mut f64,
lda_inv: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCmatinvBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *const cuComplex,
lda: core::ffi::c_int,
Ainv: *const *mut cuComplex,
lda_inv: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZmatinvBatched(
handle: cublasHandle_t,
n: core::ffi::c_int,
A: *const *const cuDoubleComplex,
lda: core::ffi::c_int,
Ainv: *const *mut cuDoubleComplex,
lda_inv: core::ffi::c_int,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgeqrfBatched(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
Aarray: *const *mut f32,
lda: core::ffi::c_int,
TauArray: *const *mut f32,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgeqrfBatched(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
Aarray: *const *mut f64,
lda: core::ffi::c_int,
TauArray: *const *mut f64,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgeqrfBatched(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
Aarray: *const *mut cuComplex,
lda: core::ffi::c_int,
TauArray: *const *mut cuComplex,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgeqrfBatched(
handle: cublasHandle_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
Aarray: *const *mut cuDoubleComplex,
lda: core::ffi::c_int,
TauArray: *const *mut cuDoubleComplex,
info: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSgelsBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
nrhs: core::ffi::c_int,
Aarray: *const *mut f32,
lda: core::ffi::c_int,
Carray: *const *mut f32,
ldc: core::ffi::c_int,
info: *mut core::ffi::c_int,
devInfoArray: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDgelsBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
nrhs: core::ffi::c_int,
Aarray: *const *mut f64,
lda: core::ffi::c_int,
Carray: *const *mut f64,
ldc: core::ffi::c_int,
info: *mut core::ffi::c_int,
devInfoArray: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCgelsBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
nrhs: core::ffi::c_int,
Aarray: *const *mut cuComplex,
lda: core::ffi::c_int,
Carray: *const *mut cuComplex,
ldc: core::ffi::c_int,
info: *mut core::ffi::c_int,
devInfoArray: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZgelsBatched(
handle: cublasHandle_t,
trans: cublasOperation_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
nrhs: core::ffi::c_int,
Aarray: *const *mut cuDoubleComplex,
lda: core::ffi::c_int,
Carray: *const *mut cuDoubleComplex,
ldc: core::ffi::c_int,
info: *mut core::ffi::c_int,
devInfoArray: *mut core::ffi::c_int,
batchSize: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasSdgmm(
handle: cublasHandle_t,
mode: cublasSideMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
A: *const f32,
lda: core::ffi::c_int,
x: *const f32,
incx: core::ffi::c_int,
C: *mut f32,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDdgmm(
handle: cublasHandle_t,
mode: cublasSideMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
A: *const f64,
lda: core::ffi::c_int,
x: *const f64,
incx: core::ffi::c_int,
C: *mut f64,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCdgmm(
handle: cublasHandle_t,
mode: cublasSideMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
A: *const cuComplex,
lda: core::ffi::c_int,
x: *const cuComplex,
incx: core::ffi::c_int,
C: *mut cuComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZdgmm(
handle: cublasHandle_t,
mode: cublasSideMode_t,
m: core::ffi::c_int,
n: core::ffi::c_int,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
x: *const cuDoubleComplex,
incx: core::ffi::c_int,
C: *mut cuDoubleComplex,
ldc: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStpttr(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
AP: *const f32,
A: *mut f32,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtpttr(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
AP: *const f64,
A: *mut f64,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtpttr(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
AP: *const cuComplex,
A: *mut cuComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtpttr(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
AP: *const cuDoubleComplex,
A: *mut cuDoubleComplex,
lda: core::ffi::c_int,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasStrttp(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
A: *const f32,
lda: core::ffi::c_int,
AP: *mut f32,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasDtrttp(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
A: *const f64,
lda: core::ffi::c_int,
AP: *mut f64,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasCtrttp(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
A: *const cuComplex,
lda: core::ffi::c_int,
AP: *mut cuComplex,
) -> cublasStatus_t;
}
extern "C" {
pub fn cublasZtrttp(
handle: cublasHandle_t,
uplo: cublasFillMode_t,
n: core::ffi::c_int,
A: *const cuDoubleComplex,
lda: core::ffi::c_int,
AP: *mut cuDoubleComplex,
) -> cublasStatus_t;
}
#include "cublas_v2.h"
#!/bin/bash
# Requires rust-bindgen 0.68.1 or later
set -exu
BINDGEN_EXTRA_CLANG_ARGS="-D__CUDA_BF16_TYPES_EXIST__" \
bindgen \
--allowlist-type="^cublasLt.*" \
--allowlist-var="^cublasLt.*" \
--allowlist-function="^cublasLt.*" \
--default-enum-style=rust \
--no-doc-comments \
--with-derive-default \
--with-derive-eq \
--with-derive-hash \
--with-derive-ord \
--use-core \
wrapper.h -- -I/usr/local/cuda/include \
> sys.rs
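// Module layout (orientation comments):
// - `sys`:    raw cublasLt FFI bindings generated by the bindgen script above
// - `result`: thin fallible wrappers that turn `cublasStatus_t` into `Result<_, CublasError>`
// - `safe`:   higher-level, RAII-style helpers (`CudaBlasLT`, `Matmul`), re-exported below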
pub mod result;
pub mod safe;
#[allow(warnings)]
pub mod sys;
pub use safe::*;
use super::sys;
use crate::cublaslt::sys::cublasLtMatmulAlgo_t;
use core::ffi::c_void;
use core::mem::MaybeUninit;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct CublasError(pub sys::cublasStatus_t);
impl sys::cublasStatus_t {
pub fn result(self) -> Result<(), CublasError> {
match self {
sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
_ => Err(CublasError(self)),
}
}
}
#[cfg(feature = "std")]
impl std::fmt::Display for CublasError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
#[cfg(feature = "std")]
impl std::error::Error for CublasError {}
/// Creates a handle to the cuBLASLT library. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltcreate)
pub fn create_handle() -> Result<sys::cublasLtHandle_t, CublasError> {
let mut handle = MaybeUninit::uninit();
unsafe {
sys::cublasLtCreate(handle.as_mut_ptr()).result()?;
Ok(handle.assume_init())
}
}
/// Destroys a handle previously created with [create_handle()]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltdestroy)
///
/// # Safety
///
/// `handle` must not have been freed already.
pub unsafe fn destroy_handle(handle: sys::cublasLtHandle_t) -> Result<(), CublasError> {
sys::cublasLtDestroy(handle).result()
}
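// A minimal lifecycle sketch (illustrative only; it assumes a CUDA context is
// current on the calling thread and is not part of the library API):
//
//     let handle = create_handle()?;
//     // ... create descriptors, run matmuls ...
//     unsafe { destroy_handle(handle)? };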
/// Creates a matrix layout descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutcreate)
pub fn create_matrix_layout(
matrix_type: sys::cudaDataType,
rows: u64,
cols: u64,
ld: i64,
) -> Result<sys::cublasLtMatrixLayout_t, CublasError> {
let mut matrix_layout = MaybeUninit::uninit();
unsafe {
sys::cublasLtMatrixLayoutCreate(matrix_layout.as_mut_ptr(), matrix_type, rows, cols, ld)
.result()?;
Ok(matrix_layout.assume_init())
}
}
/// Sets the value of the specified attribute belonging to a previously created matrix layout
/// descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutsetattribute)
///
/// # Safety
/// `matrix_layout` must not have been freed already.
pub unsafe fn set_matrix_layout_attribute(
matrix_layout: sys::cublasLtMatrixLayout_t,
attr: sys::cublasLtMatrixLayoutAttribute_t,
buf: *const c_void,
buf_size: usize,
) -> Result<(), CublasError> {
sys::cublasLtMatrixLayoutSetAttribute(matrix_layout, attr, buf, buf_size).result()
}
/// Destroys a matrix layout previously created with [create_matrix_layout(...)]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutdestroy)
///
/// # Safety
///
/// `matrix_layout` must not have been freed already.
pub unsafe fn destroy_matrix_layout(
matrix_layout: sys::cublasLtMatrixLayout_t,
) -> Result<(), CublasError> {
sys::cublasLtMatrixLayoutDestroy(matrix_layout).result()
}
/// Creates a matrix multiply descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmuldesccreate)
pub fn create_matmul_desc(
compute_type: sys::cublasComputeType_t,
scale_type: sys::cudaDataType,
) -> Result<sys::cublasLtMatmulDesc_t, CublasError> {
let mut matmul_desc = MaybeUninit::uninit();
unsafe {
sys::cublasLtMatmulDescCreate(matmul_desc.as_mut_ptr(), compute_type, scale_type)
.result()?;
Ok(matmul_desc.assume_init())
}
}
/// Sets the value of the specified attribute belonging to a previously created matrix multiply
/// descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmuldescsetattribute)
///
/// # Safety
/// `matmul_desc` must not have been freed already.
pub unsafe fn set_matmul_desc_attribute(
matmul_desc: sys::cublasLtMatmulDesc_t,
attr: sys::cublasLtMatmulDescAttributes_t,
buf: *const c_void,
buf_size: usize,
) -> Result<(), CublasError> {
sys::cublasLtMatmulDescSetAttribute(matmul_desc, attr, buf, buf_size).result()
}
/// Destroys a matrix multiply descriptor previously created with [create_matmul_desc(...)]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmuldescdestroy)
///
/// # Safety
///
/// `matmul_desc` must not have been freed already.
pub unsafe fn destroy_matmul_desc(
matmul_desc: sys::cublasLtMatmulDesc_t,
) -> Result<(), CublasError> {
sys::cublasLtMatmulDescDestroy(matmul_desc).result()
}
/// Creates a matrix multiply heuristic search preferences descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulpreferencecreate)
pub fn create_matmul_pref() -> Result<sys::cublasLtMatmulPreference_t, CublasError> {
let mut matmul_pref = MaybeUninit::uninit();
unsafe {
sys::cublasLtMatmulPreferenceCreate(matmul_pref.as_mut_ptr()).result()?;
Ok(matmul_pref.assume_init())
}
}
/// Sets the value of the specified attribute belonging to a previously created matrix multiply
/// preferences descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulpreferencesetattribute)
///
/// # Safety
/// `matmul_pref` must not have been freed already.
pub unsafe fn set_matmul_pref_attribute(
matmul_pref: sys::cublasLtMatmulPreference_t,
attr: sys::cublasLtMatmulPreferenceAttributes_t,
buf: *const c_void,
buf_size: usize,
) -> Result<(), CublasError> {
sys::cublasLtMatmulPreferenceSetAttribute(matmul_pref, attr, buf, buf_size).result()
}
/// Destroys a matrix multiply preferences descriptor previously created
/// with [create_matmul_pref()]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulpreferencedestroy)
///
/// # Safety
///
/// `matmul_pref` must not have been freed already.
pub unsafe fn destroy_matmul_pref(
matmul_pref: sys::cublasLtMatmulPreference_t,
) -> Result<(), CublasError> {
sys::cublasLtMatmulPreferenceDestroy(matmul_pref).result()
}
/// Retrieves the fastest possible algorithm for the matrix multiply operation
/// given input matrices A, B and C and the output matrix D. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulalgogetheuristic)
///
/// # Safety
/// None of the parameters may have been freed already, and the layouts must describe valid allocations.
pub unsafe fn get_matmul_algo_heuristic(
handle: sys::cublasLtHandle_t,
matmul_desc: sys::cublasLtMatmulDesc_t,
a_layout: sys::cublasLtMatrixLayout_t,
b_layout: sys::cublasLtMatrixLayout_t,
c_layout: sys::cublasLtMatrixLayout_t,
d_layout: sys::cublasLtMatrixLayout_t,
matmul_pref: sys::cublasLtMatmulPreference_t,
) -> Result<sys::cublasLtMatmulHeuristicResult_t, CublasError> {
let mut matmul_heuristic = MaybeUninit::uninit();
let mut algo_count = 0;
sys::cublasLtMatmulAlgoGetHeuristic(
handle,
matmul_desc,
a_layout,
b_layout,
c_layout,
d_layout,
matmul_pref,
1, // only select the fastest algo
matmul_heuristic.as_mut_ptr(),
&mut algo_count,
)
.result()?;
if algo_count == 0 {
return Err(CublasError(
sys::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED,
));
}
let matmul_heuristic = matmul_heuristic.assume_init();
matmul_heuristic.state.result()?;
Ok(matmul_heuristic)
}
/// Computes the matrix multiplication of matrices A and B to produce the output matrix D,
/// according to the following operation: D = alpha*(A*B) + beta*(C)
/// where A, B, and C are input matrices, and alpha and beta are input scalars. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul)
///
/// # Safety
/// None of the `sys` objects may have been freed already.
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul(
handle: sys::cublasLtHandle_t,
matmul_desc: sys::cublasLtMatmulDesc_t,
alpha: *const c_void,
beta: *const c_void,
a: *const c_void,
a_layout: sys::cublasLtMatrixLayout_t,
b: *const c_void,
b_layout: sys::cublasLtMatrixLayout_t,
c: *const c_void,
c_layout: sys::cublasLtMatrixLayout_t,
d: *mut c_void,
d_layout: sys::cublasLtMatrixLayout_t,
algo: *const cublasLtMatmulAlgo_t,
workspace: *mut c_void,
workspace_size: usize,
stream: sys::cudaStream_t,
) -> Result<(), CublasError> {
sys::cublasLtMatmul(
handle,
matmul_desc,
alpha,
a,
a_layout,
b,
b_layout,
beta,
c,
c_layout,
d,
d_layout,
algo,
workspace,
workspace_size,
stream,
)
.result()
}
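// Illustrative end-to-end sketch of the call sequence exposed by this module
// (not used by the library). The raw device pointers, the column-major layouts
// with `ld = rows`, and the FP32 data/compute types are assumptions made for
// the example; error paths leak the descriptors, which the RAII wrappers in
// `safe.rs` avoid.
#[allow(dead_code, clippy::too_many_arguments)]
unsafe fn example_matmul_f32(
    a: *const c_void,       // device pointer to A, m x k, column-major
    b: *const c_void,       // device pointer to B, k x n, column-major
    d: *mut c_void,         // device pointer to D, m x n, column-major (also used as C)
    m: u64,
    n: u64,
    k: u64,
    workspace: *mut c_void, // device workspace buffer
    workspace_size: usize,
    stream: sys::cudaStream_t,
) -> Result<(), CublasError> {
    let handle = create_handle()?;
    // Describe the operands: CUDA_R_32F, column-major, leading dimension = rows.
    let a_layout = create_matrix_layout(sys::cudaDataType_t::CUDA_R_32F, m, k, m as i64)?;
    let b_layout = create_matrix_layout(sys::cudaDataType_t::CUDA_R_32F, k, n, k as i64)?;
    let d_layout = create_matrix_layout(sys::cudaDataType_t::CUDA_R_32F, m, n, m as i64)?;
    // FP32 accumulation with FP32 alpha/beta scaling.
    let desc = create_matmul_desc(
        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
        sys::cudaDataType_t::CUDA_R_32F,
    )?;
    // Tell the heuristic how much workspace it is allowed to plan for.
    let pref = create_matmul_pref()?;
    unsafe {
        set_matmul_pref_attribute(
            pref,
            sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
            (&workspace_size) as *const usize as *const c_void,
            core::mem::size_of::<usize>(),
        )?;
    }
    let heuristic = unsafe {
        get_matmul_algo_heuristic(handle, desc, a_layout, b_layout, d_layout, d_layout, pref)?
    };
    // D = 1.0 * (A * B) + 0.0 * C, with C aliased to D.
    let (alpha, beta) = (1.0f32, 0.0f32);
    unsafe {
        matmul(
            handle,
            desc,
            (&alpha) as *const f32 as *const c_void,
            (&beta) as *const f32 as *const c_void,
            a,
            a_layout,
            b,
            b_layout,
            d as *const c_void,
            d_layout,
            d,
            d_layout,
            (&heuristic.algo) as *const _,
            workspace,
            workspace_size,
            stream,
        )?;
        // Tear everything down again, in reverse order of creation.
        destroy_matmul_pref(pref)?;
        destroy_matmul_desc(desc)?;
        destroy_matrix_layout(d_layout)?;
        destroy_matrix_layout(b_layout)?;
        destroy_matrix_layout(a_layout)?;
        destroy_handle(handle)
    }
}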
//! Safe abstractions around [crate::cublaslt::result] for doing matmul.
use super::{result, result::CublasError, sys};
use crate::cublaslt::result::set_matrix_layout_attribute;
use crate::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream};
use crate::driver::{CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError};
use core::ffi::c_int;
use core::mem;
use std::sync::Arc;
/// Wrapper around [sys::cublasLtHandle_t]
///
/// 1. Create with [CudaBlasLT::new()]
/// 2. Execute the matmul kernel with [Matmul::matmul()]. `f32` is supported; `f16` and `bf16`
///    are supported if the `f16` feature is activated
///
/// Note: This maintains an instance of [`Arc<CudaDevice>`], so it will prevent the device
/// from being dropped. Kernels will be launched on the device's default stream.
#[derive(Debug)]
pub struct CudaBlasLT {
handle: sys::cublasLtHandle_t,
workspace: Workspace,
device: Arc<CudaDevice>,
}
unsafe impl Send for CudaBlasLT {}
unsafe impl Sync for CudaBlasLT {}
impl CudaBlasLT {
/// Creates a new cublasLt handle.
pub fn new(device: Arc<CudaDevice>) -> Result<Self, CublasError> {
let handle = result::create_handle()?;
let workspace = Workspace::new(device.clone()).unwrap();
Ok(Self {
handle,
workspace,
device,
})
}
}
impl Drop for CudaBlasLT {
fn drop(&mut self) {
let handle = mem::replace(&mut self.handle, std::ptr::null_mut());
if !handle.is_null() {
unsafe { result::destroy_handle(handle) }.unwrap();
}
}
}
/// User-owned CublasLt workspace buffer.
/// The workspace is initialised following the Nvidia recommendations:
///
/// 1. NVIDIA Hopper Architecture: 32 MiB
/// 2. Other: 4 MiB
#[derive(Debug, Clone)]
pub struct Workspace {
pub(crate) buffer: CudaSlice<u8>,
pub(crate) size: usize,
}
impl Workspace {
/// Creates a CublasLt workspace buffer on the provided device
pub fn new(device: Arc<CudaDevice>) -> Result<Self, DriverError> {
device.bind_to_thread()?;
let major =
device.attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
let workspace_size = if major >= 9 { 33_554_432 } else { 4_194_304 };
let buffer = unsafe { device.alloc::<u8>(workspace_size)? };
Ok(Self {
buffer,
size: workspace_size,
})
}
}
/// Available activation for kernel fusing in matmul
#[derive(Debug, Clone)]
pub enum Activation {
Relu,
Gelu,
}
/// MatrixLayout helper type
struct MatrixLayout {
handle: sys::cublasLtMatrixLayout_t,
}
impl MatrixLayout {
fn new(
matrix_type: sys::cudaDataType,
rows: u64,
cols: u64,
ld: i64,
) -> Result<Self, CublasError> {
let handle = result::create_matrix_layout(matrix_type, rows, cols, ld)?;
Ok(Self { handle })
}
fn set_batch(&self, size: c_int, stride: i64) -> Result<(), CublasError> {
unsafe {
// Set batch size
set_matrix_layout_attribute(
self.handle,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
(&size) as *const _ as *const _,
mem::size_of::<c_int>(),
)?;
// Set batch stride
set_matrix_layout_attribute(
self.handle,
sys::cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
(&stride) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}
Ok(())
}
}
impl Drop for MatrixLayout {
fn drop(&mut self) {
// panic on failure
unsafe {
result::destroy_matrix_layout(self.handle).expect("Unable to destroy matrix layout")
}
}
}
enum Matrix {
A,
B,
#[allow(dead_code)]
C,
}
/// MatmulDesc helper type
struct MatmulDesc {
handle: sys::cublasLtMatmulDesc_t,
}
impl MatmulDesc {
fn new(
compute_type: sys::cublasComputeType_t,
scale_type: sys::cudaDataType,
) -> Result<Self, CublasError> {
let handle = result::create_matmul_desc(compute_type, scale_type)?;
Ok(Self { handle })
}
fn set_transpose(&self, transpose: bool, matrix: Matrix) -> Result<(), CublasError> {
// Set transpose
// 1 == T, 0 == N
let transpose = transpose as i32;
let attr = match matrix {
Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSC,
};
unsafe {
result::set_matmul_desc_attribute(
self.handle,
attr,
(&transpose) as *const _ as *const _,
mem::size_of::<u32>(),
)?;
}
Ok(())
}
// Epilogue system can be leveraged to fuse add and activation operations
fn set_epilogue(
&self,
act: Option<&Activation>,
bias_ptr: Option<&CUdeviceptr>,
stride_bias: Option<i64>,
) -> Result<(), CublasError> {
let epilogue = if let Some(bias_ptr) = bias_ptr {
let epilogue = act
.map(|act| match act {
// Act + bias
Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
})
// Only bias
.unwrap_or(sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS);
// Set bias CUdeviceptr in matmul_desc
unsafe {
result::set_matmul_desc_attribute(
self.handle,
sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
bias_ptr as *const CUdeviceptr as *const _,
mem::size_of::<CUdeviceptr>(),
)?;
}
if let Some(stride_bias) = stride_bias {
// Set bias batch stride
unsafe {
result::set_matmul_desc_attribute(
self.handle,
sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE,
(&stride_bias) as *const _ as *const _,
mem::size_of::<i64>(),
)?;
}
}
epilogue
} else if let Some(act) = act {
// Only Act
match act {
Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
}
} else {
// No epilogue
sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT
};
// Set epilogue
unsafe {
result::set_matmul_desc_attribute(
self.handle,
sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE,
(&epilogue) as *const _ as *const _,
mem::size_of::<sys::cublasLtEpilogue_t>(),
)?;
}
Ok(())
}
}
impl Drop for MatmulDesc {
fn drop(&mut self) {
unsafe { result::destroy_matmul_desc(self.handle).expect("Unable to destroy matmul desc") }
}
}
/// MatmulPref helper type
struct MatmulPref {
handle: sys::cublasLtMatmulPreference_t,
}
impl MatmulPref {
fn new() -> Result<Self, CublasError> {
let handle = result::create_matmul_pref()?;
Ok(Self { handle })
}
fn set_workspace_size(&self, size: usize) -> Result<(), CublasError> {
unsafe {
// Set workspace size
result::set_matmul_pref_attribute(
self.handle,
sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
(&size) as *const _ as *const _,
mem::size_of::<usize>(),
)?;
}
Ok(())
}
}
impl Drop for MatmulPref {
fn drop(&mut self) {
unsafe { result::destroy_matmul_pref(self.handle).expect("Unable to destroy matmul pref") }
}
}
/// [Matmul] super-trait
pub trait MatmulShared {
/// Returns a reference to the underlying cublasLt handle.
fn handle(&self) -> &sys::cublasLtHandle_t;
/// Returns a reference to the underlying cublasLt workspace
fn workspace(&self) -> &Workspace;
/// Returns a reference to the underlying stream
fn stream(&self) -> &CUstream;
}
/// Configuration for [Matmul]
#[derive(Debug, Copy, Clone)]
pub struct MatmulConfig {
pub transa: bool,
pub transb: bool,
pub m: u64,
pub n: u64,
pub k: u64,
pub alpha: f32,
pub lda: i64,
pub ldb: i64,
pub beta: f32,
pub ldc: i64,
pub stride_a: Option<i64>,
pub stride_b: Option<i64>,
pub stride_c: Option<i64>,
pub stride_bias: Option<i64>,
pub batch_size: Option<c_int>,
}
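// cublasLt is column-major. A row-major product `C (M x N) = A (M x K) * B (K x N)`
// can therefore be expressed by treating the row-major buffers as column-major
// transposes and swapping the operands, as the tests at the bottom of this file do
// (illustrative values only):
//
//     let cfg = MatmulConfig {
//         transa: false, transb: false,
//         m: N as u64, n: M as u64, k: K as u64,
//         alpha: 1.0, beta: 0.0,
//         lda: N as i64, ldb: K as i64, ldc: N as i64,
//         stride_a: None, stride_b: None, stride_c: None,
//         stride_bias: None, batch_size: None,
//     };
//     // note the swapped operand order: B first, then A
//     unsafe { blas.matmul(cfg, &b_dev, &a_dev, &mut c_dev, Some(&bias), None)? };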
/// Matrix matrix multiplication with elements of type `T`.
pub trait Matmul<T>: MatmulShared {
/// Underlying CUDA Type for `T`
fn matrix_type() -> sys::cudaDataType;
/// Underlying CUDA Compute Type for `T`
fn compute_type() -> sys::cublasComputeType_t;
/// Matrix matrix multiplication. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul)
///
/// # Safety
/// This is unsafe because improper arguments may lead to invalid
/// memory accesses.
unsafe fn matmul<I: DevicePtr<T>, O: DevicePtrMut<T>>(
&self,
cfg: MatmulConfig,
a: &I,
b: &I,
c: &mut O,
bias: Option<&I>,
act: Option<&Activation>,
) -> Result<(), CublasError> {
let (a_rows, a_cols) = if cfg.transa {
(cfg.k, cfg.m)
} else {
(cfg.m, cfg.k)
};
let (b_rows, b_cols) = if cfg.transb {
(cfg.n, cfg.k)
} else {
(cfg.k, cfg.n)
};
// Creates matrix layouts
let a_layout = MatrixLayout::new(Self::matrix_type(), a_rows, a_cols, cfg.lda)?;
if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) {
a_layout.set_batch(batch_size, stride_a)?;
}
let b_layout = MatrixLayout::new(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?;
if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) {
b_layout.set_batch(batch_size, stride_b)?;
}
let c_layout = MatrixLayout::new(Self::matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
c_layout.set_batch(batch_size, stride_c)?;
}
// Matmul description
let matmul_desc = MatmulDesc::new(Self::compute_type(), sys::cudaDataType_t::CUDA_R_32F)?;
// Set transa
matmul_desc.set_transpose(cfg.transa, Matrix::A)?;
// Set transb
matmul_desc.set_transpose(cfg.transb, Matrix::B)?;
// Epilogue system can be leveraged to fuse add and activation operations
matmul_desc.set_epilogue(act, bias.map(|b| b.device_ptr()), cfg.stride_bias)?;
// Create matmul heuristic search preferences
let matmul_pref = MatmulPref::new()?;
// Set workspace size
matmul_pref.set_workspace_size(self.workspace().size)?;
// Get heuristic given Config, bias, act and workspace size
let heuristic = result::get_matmul_algo_heuristic(
*self.handle(),
matmul_desc.handle,
a_layout.handle,
b_layout.handle,
c_layout.handle,
c_layout.handle,
matmul_pref.handle,
)?;
// Launch matmul kernel
result::matmul(
*self.handle(),
matmul_desc.handle,
(&cfg.alpha) as *const _ as *const _,
(&cfg.beta) as *const _ as *const _,
*a.device_ptr() as *const _,
a_layout.handle,
*b.device_ptr() as *const _,
b_layout.handle,
*c.device_ptr_mut() as *const _,
c_layout.handle,
*c.device_ptr_mut() as *mut _,
c_layout.handle,
(&heuristic.algo) as *const _,
*self.workspace().buffer.device_ptr() as *const CUdeviceptr as *mut _,
self.workspace().size,
*self.stream() as *mut _,
)
}
}
impl MatmulShared for CudaBlasLT {
fn handle(&self) -> &sys::cublasLtHandle_t {
&self.handle
}
fn workspace(&self) -> &Workspace {
&self.workspace
}
fn stream(&self) -> &CUstream {
&self.device.stream
}
}
impl Matmul<f32> for CudaBlasLT {
fn matrix_type() -> sys::cudaDataType {
sys::cudaDataType_t::CUDA_R_32F
}
fn compute_type() -> sys::cublasComputeType_t {
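// CUBLAS_COMPUTE_32F_FAST_TF32 permits TF32 tensor-core math: faster than
// plain FP32 at the cost of reduced mantissa precision.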
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
}
}
#[cfg(feature = "f16")]
impl Matmul<half::f16> for CudaBlasLT {
fn matrix_type() -> sys::cudaDataType {
sys::cudaDataType_t::CUDA_R_16F
}
fn compute_type() -> sys::cublasComputeType_t {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
}
}
#[cfg(feature = "f16")]
impl Matmul<half::bf16> for CudaBlasLT {
fn matrix_type() -> sys::cudaDataType {
sys::cudaDataType_t::CUDA_R_16BF
}
fn compute_type() -> sys::cublasComputeType_t {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::needless_range_loop)]
use super::*;
use std::ffi::CString;
fn matmul_truth<T, const M: usize, const N: usize, const K: usize>(
alpha: T,
a: &[[T; K]; M],
b: &[[T; N]; K],
beta: T,
c: &mut [[T; N]; M],
) where
T: Copy + Clone + std::ops::AddAssign + std::ops::MulAssign + std::ops::Mul<T, Output = T>,
{
for m in 0..M {
for n in 0..N {
c[m][n] *= beta;
}
}
for m in 0..M {
for n in 0..N {
for k in 0..K {
c[m][n] += alpha * a[m][k] * b[k][n];
}
}
}
}
#[test]
fn test_matmul_f32() {
let logpath = CString::new("log_matmul_f32").unwrap();
unsafe { sys::cublasLtLoggerSetLevel(4).result().unwrap() };
unsafe {
sys::cublasLtLoggerOpenFile(logpath.as_ptr())
.result()
.unwrap()
};
let dev = CudaDevice::new(0).unwrap();
let blas = CudaBlasLT::new(dev.clone()).unwrap();
const M: usize = 3;
const K: usize = 4;
const N: usize = 5;
let a: [[f32; K]; M] = [
[-0.5944882, 1.8055636, 0.52204555, -0.00397902],
[-0.38346434, -0.38013917, 0.4198623, -0.22479166],
[-1.6661372, -0.4568837, -0.9043474, 0.39125723],
];
let b: [[f32; N]; K] = [
[1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938],
[1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096],
[1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629],
[3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792],
];
let mut c: [[f32; N]; M] = [[0.0; N]; M];
matmul_truth(1.0, &a, &b, 0.0, &mut c);
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy::<f32>(&[
-0.5944882, 1.8055636, 0.52204555, -0.00397902,
-0.38346434, -0.38013917, 0.4198623, -0.22479166,
-1.6661372, -0.4568837, -0.9043474, 0.39125723,
]).unwrap();
#[rustfmt::skip]
let b_dev = dev.htod_sync_copy::<f32>(&[
1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938,
1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096,
1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629,
3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792,
]).unwrap();
#[rustfmt::skip]
let bias = dev.alloc_zeros::<f32>(N).unwrap();
let mut c_dev = dev.alloc_zeros::<f32>(M * N).unwrap();
unsafe {
blas.matmul(
MatmulConfig {
transa: false,
transb: false,
m: N as u64,
n: M as u64,
k: K as u64,
alpha: 1.0,
lda: N as i64,
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
&mut c_dev,
Some(&bias),
None,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for m in 0..M {
for n in 0..N {
let found = c_host[m * N + n];
let expected = c[m][n];
assert!(
(found - expected).abs() <= 1e-6,
"found={found:?}, expected={expected:?}"
);
}
}
}
#[cfg(feature = "f16")]
#[test]
fn test_matmul_half() {
let logpath = CString::new("log_matmul_half").unwrap();
unsafe { sys::cublasLtLoggerSetLevel(4).result().unwrap() };
unsafe {
sys::cublasLtLoggerOpenFile(logpath.as_ptr())
.result()
.unwrap()
};
let dev = CudaDevice::new(0).unwrap();
let blas = CudaBlasLT::new(dev.clone()).unwrap();
const M: usize = 2;
const K: usize = 4;
const N: usize = 6;
let a: [[half::f16; K]; M] = [
[-0.5944882, 1.8055636, 0.52204555, -0.00397902],
[-0.38346434, -0.38013917, 0.4198623, -0.22479166],
]
.map(|r| r.map(half::f16::from_f32));
let b: [[half::f16; N]; K] = [
[
1.1292169,
-0.13450263,
0.62789696,
-0.5685516,
0.21946938,
-1.6661372,
],
[
1.0585804,
-0.39789402,
0.90205914,
0.989318,
-0.3443096,
-0.4568837,
],
[
1.3412506,
0.3059701,
-0.9714474,
-0.36113533,
-1.6809629,
-0.9043474,
],
[
3.4746711,
-1.0930681,
0.16502666,
-0.59988785,
0.41375792,
0.39125723,
],
]
.map(|r| r.map(half::f16::from_f32));
let mut c: [[half::f16; N]; M] = [[0.0; N]; M].map(|r| r.map(half::f16::from_f32));
matmul_truth(
half::f16::from_f32(1.0),
&a,
&b,
half::f16::from_f32(0.0),
&mut c,
);
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy::<half::f16>(&[
-0.5944882, 1.8055636, 0.52204555, -0.00397902,
-0.38346434, -0.38013917, 0.4198623, -0.22479166,
].map(half::f16::from_f32)).unwrap();
#[rustfmt::skip]
let b_dev = dev.htod_sync_copy::<half::f16>(&[
1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938, -1.6661372,
1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096, -0.4568837,
1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629, -0.9043474,
3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792, 0.39125723,
].map(half::f16::from_f32)).unwrap();
let bias = dev.alloc_zeros::<half::f16>(N).unwrap();
let mut c_dev = dev.alloc_zeros::<half::f16>(M * N).unwrap();
unsafe {
blas.matmul(
MatmulConfig {
transa: false,
transb: false,
m: N as u64,
n: M as u64,
k: K as u64,
alpha: 1.0,
lda: N as i64,
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
&mut c_dev,
Some(&bias),
None,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for m in 0..M {
for n in 0..N {
let found = c_host[m * N + n];
let expected = c[m][n];
assert!(
(found.to_f32() - expected.to_f32()).abs() <= 1e-2,
"found={found:?}, expected={expected:?}"
);
}
}
#[rustfmt::skip]
let a_dev = dev.htod_sync_copy::<half::bf16>(&[
-0.5944882, 1.8055636, 0.52204555, -0.00397902,
-0.38346434, -0.38013917, 0.4198623, -0.22479166,
].map(half::bf16::from_f32)).unwrap();
#[rustfmt::skip]
let b_dev = dev.htod_sync_copy::<half::bf16>(&[
1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938, -1.6661372,
1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096, -0.4568837,
1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629, -0.9043474,
3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792, 0.39125723,
].map(half::bf16::from_f32)).unwrap();
let bias = dev.alloc_zeros::<half::bf16>(N).unwrap();
let mut c_dev = dev.alloc_zeros::<half::bf16>(M * N).unwrap();
unsafe {
blas.matmul(
MatmulConfig {
transa: false,
transb: false,
m: N as u64,
n: M as u64,
k: K as u64,
alpha: 1.0,
lda: N as i64,
ldb: K as i64,
beta: 0.0,
ldc: N as i64,
stride_a: None,
stride_b: None,
stride_c: None,
stride_bias: None,
batch_size: None,
},
&b_dev,
&a_dev,
&mut c_dev,
Some(&bias),
None,
)
}
.unwrap();
let c_host = dev.sync_reclaim(c_dev).unwrap();
for m in 0..M {
for n in 0..N {
let found = c_host[m * N + n];
let expected = c[m][n];
assert!(
(found.to_f32() - expected.to_f32()).abs() <= 1e-2,
"found={found:?}, expected={expected:?}"
);
}
}
}
}