Commit 3ec8c534 (author: yongshk): Initial commit
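// build.rs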
use std::{env, path::PathBuf, process::Command};
use bindgen::CargoCallbacks;
use regex::Regex;
fn main() {
// Tell cargo to invalidate the built crate whenever files of interest change.
println!("cargo:rerun-if-changed=src/cuda");
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
// Specify the desired architecture version.
let arch = "compute_86"; // For example, using SM 8.6 (Ampere architecture).
let code = "sm_86"; // For the same SM 8.6 (Ampere architecture).
// build the cuda kernels
let cuda_src = PathBuf::from("src/cuda/kernels/my_struct_kernel.cu");
let ptx_file = out_dir.join("my_struct_kernel.ptx");
let nvcc_status = Command::new("nvcc")
.arg("-ptx")
.arg("-o")
.arg(&ptx_file)
.arg(&cuda_src)
.arg(format!("-arch={}", arch))
.arg(format!("-code={}", code))
.status()
.unwrap();
assert!(
nvcc_status.success(),
"Failed to compile CUDA source to PTX."
);
// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// The input header we would like to generate
// bindings for.
.header("src/cuda/includes/wrapper.h")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(CargoCallbacks))
// we use "no_copy" and "no_debug" here because we don't know if we can
// safely derive them for our structs in C code (they may contain raw pointers)
.no_copy(".*")
.no_debug(".*")
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");
// we need to make modifications to the generated code
let generated_bindings = bindings.to_string();
// Regex to find raw pointers to float and replace them with CudaSlice<f32>
// You can copy this regex to add/modify other types of pointers, for example "*mut i32"
let pointer_regex = Regex::new(r"\*mut f32").unwrap();
let modified_bindings = pointer_regex.replace_all(&generated_bindings, "CudaSlice<f32>");
// Write the bindings to the $OUT_DIR/bindings.rs file.
std::fs::write(out_dir.join("bindings.rs"), modified_bindings.as_bytes())
.expect("Failed to write bindings");
}
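// A build script like this also needs its crates declared under
// [build-dependencies] in Cargo.toml. A minimal sketch (the version
// numbers here are assumptions, not taken from this commit):
//
//     [build-dependencies]
//     bindgen = "0.68"
//     regex = "1"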
// my_struct.h
#ifndef MY_STRUCT_H
#define MY_STRUCT_H
struct MyStruct
{
/// example data containing fixed length array of floats
float data[4];
};
#endif // MY_STRUCT_H
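// wrapper.h (src/cuda/includes/wrapper.h, the header handed to bindgen)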
#include "my_struct.h"
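// my_struct_kernel.cu (src/cuda/kernels/my_struct_kernel.cu)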
#include "../includes/my_struct.h"
extern "C" __global__ void my_struct_kernel(MyStruct *my_structs, const size_t n)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n)
{
my_structs[i].data[0] += 1;
my_structs[i].data[1] += 1;
my_structs[i].data[2] += 1;
my_structs[i].data[3] += 1;
}
}
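// main.rs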
#![allow(non_snake_case)]
//! This file outlines a typical build process which can be used for more complex CUDA projects utilising this crate.
//! It does the following:
//! 1. Use a `build.rs` file to compile your CUDA code/project into a PTX file. Your CUDA code/project can be as complicated as you need it to be, including multiple files, with headers for your struct definitions, each kernel in its own file, etc.
//! 2. The build process compiles the kernels into a PTX file, which is written to the output directory
//! 3. The build process then uses the `bindgen` crate to generate Rust bindings for the structs defined in your CUDA code
//! 4. In the `main.rs` code, the PTX code is included as a string via the `include_str!` macro, which is then compiled using the functions in this crate (detailed in previous examples)
//!
//! The advantages of having this build process for more complex CUDA projects:
//! - You only need to define your structs once, in your CUDA code, and the Rust bindings are generated automatically
//! - You have full intellisense for your CUDA code since it can be stored under a separate folder or even as part of a separate project
//!
//! There are two files in this example: `main.rs` and `build.rs`. You can reference them and add to your project accordingly. The `cuda` folder in this example gives a simple example of defining structs in a separate header, including creating a `wrapper.h` header for `bindgen`
use std::time::Instant;
use cudarc::driver::{CudaDevice, LaunchConfig, DeviceRepr, DriverError, LaunchAsync};
use cudarc::nvrtc::Ptx;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
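// SAFETY: bindgen emits MyStruct as #[repr(C)] holding only a [f32; 4],
// so it is plain-old-data and can be passed to the device by value.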
unsafe impl DeviceRepr for MyStruct {}
impl Default for MyStruct {
fn default() -> Self {
Self { data: [0.0; 4] }
}
}
// include the compiled PTX code as a string
const CUDA_KERNEL_MY_STRUCT: &str = include_str!(concat!(env!("OUT_DIR"), "/my_struct_kernel.ptx"));
fn main() -> Result<(), DriverError> {
// setup GPU device
let now = Instant::now();
let gpu = CudaDevice::new(0)?;
println!("Time taken to initialise CUDA: {:.2?}", now.elapsed());
// compile ptx
let now = Instant::now();
let ptx = Ptx::from_src(CUDA_KERNEL_MY_STRUCT);
gpu.load_ptx(ptx, "my_module", &["my_struct_kernel"])?;
println!("Time taken to compile and load PTX: {:.2?}", now.elapsed());
// create data
let now = Instant::now();
let n = 10_usize;
let my_structs = vec![MyStruct { data: [1.0; 4] }; n];
// copy to GPU
let gpu_my_structs = gpu.htod_copy(my_structs)?;
println!("Time taken to initialise data: {:.2?}", now.elapsed());
let now = Instant::now();
let f = gpu.get_func("my_module", "my_struct_kernel").unwrap();
unsafe { f.launch(LaunchConfig::for_num_elems(n as u32), (&gpu_my_structs, n)) }?;
println!("Time taken to call kernel: {:.2?}", now.elapsed());
let my_structs = gpu.sync_reclaim(gpu_my_structs)?;
assert!(my_structs.iter().all(|i| i.data == [2.0; 4]));
Ok(())
}
use cudarc::driver::{CudaDevice, DriverError, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::compile_ptx;
const PTX_SRC: &str = "
extern \"C\" __global__ void matmul(float* A, float* B, float* C, int N) {
int ROW = blockIdx.y*blockDim.y+threadIdx.y;
int COL = blockIdx.x*blockDim.x+threadIdx.x;
float tmpSum = 0;
if (ROW < N && COL < N) {
// each thread computes one element of the block sub-matrix
for (int i = 0; i < N; i++) {
tmpSum += A[ROW * N + i] * B[i * N + COL];
}
// printf(\"pos, (%d, %d) - N %d - value %f\\n\", ROW, COL, N, tmpSum);
C[ROW * N + COL] = tmpSum;
}
}
";
fn main() -> Result<(), DriverError> {
let start = std::time::Instant::now();
let ptx = compile_ptx(PTX_SRC).unwrap();
println!("Compilation succeeded in {:?}", start.elapsed());
let dev = CudaDevice::new(0)?;
println!("Built in {:?}", start.elapsed());
dev.load_ptx(ptx, "matmul", &["matmul"])?;
let f = dev.get_func("matmul", "matmul").unwrap();
println!("Loaded in {:?}", start.elapsed());
let a_host = [1.0f32, 2.0, 3.0, 4.0];
let b_host = [1.0f32, 2.0, 3.0, 4.0];
let mut c_host = [0.0f32; 4];
let a_dev = dev.htod_sync_copy(&a_host)?;
let b_dev = dev.htod_sync_copy(&b_host)?;
let mut c_dev = dev.htod_sync_copy(&c_host)?;
println!("Copied in {:?}", start.elapsed());
let cfg = LaunchConfig {
block_dim: (2, 2, 1),
grid_dim: (1, 1, 1),
shared_mem_bytes: 0,
};
unsafe { f.launch(cfg, (&a_dev, &b_dev, &mut c_dev, 2i32)) }?;
dev.dtoh_sync_copy_into(&c_dev, &mut c_host)?;
println!("Found {:?} in {:?}", c_host, start.elapsed());
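// For reference: with N = 2 and A = B = [[1, 2], [3, 4]] stored row-major,
// C = A * B = [[7, 10], [15, 22]], so the line above prints
// Found [7.0, 10.0, 15.0, 22.0].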
Ok(())
}
use cudarc::nvrtc::{compile_ptx_with_opts, CompileError, CompileOptions};
fn main() -> Result<(), CompileError> {
// These map directly to the nvrtc options of the same names.
let opts = CompileOptions {
ftz: Some(true),        // flush denormal floats to zero (--ftz=true)
prec_div: Some(false),  // allow faster, less precise division (--prec-div=false)
prec_sqrt: Some(false), // allow faster, less precise square root (--prec-sqrt=false)
fmad: Some(true),       // enable fused multiply-add contraction (--fmad=true)
..Default::default()
};
let _ = compile_ptx_with_opts(
"
extern \"C\" __global__ void sin_kernel(float *out, const float *inp, int numel) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}",
opts,
)?;
println!("Compilation succeeded!");
Ok(())
}
extern "C" __global__ void sin_kernel(float *out, const float *inp, int numel) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-29745058
// Cuda compilation tools, release 11.3, V11.3.58
// Based on NVVM 7.0.1
//
.version 7.3
.target sm_52
.address_size 64
// .globl sin_kernel
.global .align 4 .b8 __cudart_i2opi_f[24] = {65, 144, 67, 60, 153, 149, 98, 219, 192, 221, 52, 245, 209, 87, 39, 252, 41, 21, 68, 78, 110, 131, 249, 162};
.visible .entry sin_kernel(
.param .u64 sin_kernel_param_0,
.param .u64 sin_kernel_param_1,
.param .u32 sin_kernel_param_2
)
{
.local .align 4 .b8 __local_depot0[28];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .pred %p<12>;
.reg .f32 %f<38>;
.reg .b32 %r<53>;
.reg .f64 %fd<3>;
.reg .b64 %rd<33>;
mov.u64 %SPL, __local_depot0;
ld.param.u64 %rd10, [sin_kernel_param_0];
ld.param.u64 %rd11, [sin_kernel_param_1];
ld.param.u32 %r19, [sin_kernel_param_2];
add.u64 %rd1, %SPL, 0;
mov.u32 %r20, %ntid.x;
mov.u32 %r21, %ctaid.x;
mov.u32 %r22, %tid.x;
mad.lo.s32 %r1, %r21, %r20, %r22;
setp.ge.s32 %p1, %r1, %r19;
@%p1 bra $L__BB0_14;
cvta.to.global.u64 %rd13, %rd11;
cvt.s64.s32 %rd2, %r1;
mul.wide.s32 %rd14, %r1, 4;
add.s64 %rd15, %rd13, %rd14;
ld.global.f32 %f1, [%rd15];
mul.f32 %f14, %f1, 0f3F22F983;
cvt.rni.s32.f32 %r52, %f14;
cvt.rn.f32.s32 %f15, %r52;
mov.f32 %f16, 0fBFC90FDA;
fma.rn.f32 %f17, %f15, %f16, %f1;
mov.f32 %f18, 0fB3A22168;
fma.rn.f32 %f19, %f15, %f18, %f17;
mov.f32 %f20, 0fA7C234C5;
fma.rn.f32 %f35, %f15, %f20, %f19;
abs.f32 %f3, %f1;
setp.leu.f32 %p2, %f3, 0f47CE4780;
@%p2 bra $L__BB0_9;
setp.eq.f32 %p3, %f3, 0f7F800000;
@%p3 bra $L__BB0_8;
bra.uni $L__BB0_3;
$L__BB0_8:
mov.f32 %f23, 0f00000000;
mul.rn.f32 %f35, %f1, %f23;
bra.uni $L__BB0_9;
$L__BB0_3:
mov.b32 %r3, %f1;
bfe.u32 %r24, %r3, 23, 8;
add.s32 %r4, %r24, -128;
shl.b32 %r25, %r3, 8;
or.b32 %r5, %r25, -2147483648;
shr.u32 %r6, %r4, 5;
mov.u64 %rd32, 0;
mov.u32 %r49, 0;
mov.u64 %rd30, __cudart_i2opi_f;
mov.u64 %rd31, %rd1;
$L__BB0_4:
.pragma "nounroll";
ld.global.nc.u32 %r26, [%rd30];
mad.wide.u32 %rd18, %r26, %r5, %rd32;
shr.u64 %rd32, %rd18, 32;
st.local.u32 [%rd31], %rd18;
add.s64 %rd31, %rd31, 4;
add.s64 %rd30, %rd30, 4;
add.s32 %r49, %r49, 1;
setp.ne.s32 %p4, %r49, 6;
@%p4 bra $L__BB0_4;
st.local.u32 [%rd1+24], %rd32;
cvt.u64.u32 %rd19, %r6;
mov.u64 %rd20, 2;
sub.s64 %rd21, %rd20, %rd19;
shl.b64 %rd22, %rd21, 2;
add.s64 %rd23, %rd1, %rd22;
add.s64 %rd9, %rd23, 16;
ld.local.u32 %r50, [%rd23+16];
ld.local.u32 %r51, [%rd23+12];
and.b32 %r11, %r4, 31;
setp.eq.s32 %p5, %r11, 0;
@%p5 bra $L__BB0_7;
mov.u32 %r27, 32;
sub.s32 %r28, %r27, %r11;
shr.u32 %r29, %r51, %r28;
shl.b32 %r30, %r50, %r11;
add.s32 %r50, %r29, %r30;
ld.local.u32 %r31, [%rd9+-8];
shr.u32 %r32, %r31, %r28;
shl.b32 %r33, %r51, %r11;
add.s32 %r51, %r32, %r33;
$L__BB0_7:
and.b32 %r34, %r3, -2147483648;
shr.u32 %r35, %r51, 30;
shl.b32 %r36, %r50, 2;
or.b32 %r37, %r35, %r36;
shr.u32 %r38, %r37, 31;
shr.u32 %r39, %r50, 30;
add.s32 %r40, %r38, %r39;
neg.s32 %r41, %r40;
setp.eq.s32 %p6, %r34, 0;
selp.b32 %r52, %r40, %r41, %p6;
setp.ne.s32 %p7, %r38, 0;
xor.b32 %r42, %r34, -2147483648;
selp.b32 %r43, %r42, %r34, %p7;
selp.b32 %r44, -1, 0, %p7;
xor.b32 %r45, %r37, %r44;
shl.b32 %r46, %r51, 2;
xor.b32 %r47, %r46, %r44;
cvt.u64.u32 %rd24, %r45;
cvt.u64.u32 %rd25, %r47;
bfi.b64 %rd26, %rd24, %rd25, 32, 32;
cvt.rn.f64.s64 %fd1, %rd26;
mul.f64 %fd2, %fd1, 0d3BF921FB54442D19;
cvt.rn.f32.f64 %f21, %fd2;
setp.eq.s32 %p8, %r43, 0;
neg.f32 %f22, %f21;
selp.f32 %f35, %f21, %f22, %p8;
$L__BB0_9:
and.b32 %r18, %r52, 1;
setp.eq.s32 %p9, %r18, 0;
selp.f32 %f7, %f35, 0f3F800000, %p9;
mul.rn.f32 %f8, %f35, %f35;
mov.f32 %f36, 0fB94D4153;
@%p9 bra $L__BB0_11;
mov.f32 %f25, 0fBAB607ED;
mov.f32 %f26, 0f37CBAC00;
fma.rn.f32 %f36, %f26, %f8, %f25;
$L__BB0_11:
selp.f32 %f27, 0f3C0885E4, 0f3D2AAABB, %p9;
fma.rn.f32 %f28, %f36, %f8, %f27;
selp.f32 %f29, 0fBE2AAAA8, 0fBEFFFFFF, %p9;
fma.rn.f32 %f30, %f28, %f8, %f29;
mov.f32 %f31, 0f00000000;
fma.rn.f32 %f32, %f8, %f7, %f31;
fma.rn.f32 %f37, %f30, %f32, %f7;
and.b32 %r48, %r52, 2;
setp.eq.s32 %p11, %r48, 0;
@%p11 bra $L__BB0_13;
mov.f32 %f34, 0fBF800000;
fma.rn.f32 %f37, %f37, %f34, %f31;
$L__BB0_13:
cvta.to.global.u64 %rd27, %rd10;
shl.b64 %rd28, %rd2, 2;
add.s64 %rd29, %rd27, %rd28;
st.global.f32 [%rd29], %f37;
$L__BB0_14:
ret;
}
#!/bin/bash
set -exu
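# Note: newer bindgen releases renamed the --whitelist-* flags to
# --allowlist-* (compare the cuBLASLt script further below).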
bindgen \
--whitelist-type="^cublas.*" \
--whitelist-function="^cublas.*" \
--default-enum-style=rust \
--no-doc-comments \
--with-derive-default \
--with-derive-eq \
--with-derive-hash \
--with-derive-ord \
--size_t-is-usize \
--use-core \
wrapper.h -- -I/usr/local/cuda/include \
> sys.rs
use super::sys;
use core::ffi::{c_int, c_longlong};
use half::f16;
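// Hand-written FFI declarations for the half-precision (f16) cuBLAS entry
// points, written out manually so the signatures can use `half::f16` directly.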
extern "C" {
pub fn cublasHgemm(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f16,
A: *const f16,
lda: c_int,
B: *const f16,
ldb: c_int,
beta: *const f16,
C: *mut f16,
ldc: c_int,
) -> sys::cublasStatus_t;
}
extern "C" {
pub fn cublasHgemmStridedBatched(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f16,
A: *const f16,
lda: c_int,
strideA: c_longlong,
B: *const f16,
ldb: c_int,
strideB: c_longlong,
beta: *const f16,
C: *mut f16,
ldc: c_int,
strideC: c_longlong,
batchCount: c_int,
) -> sys::cublasStatus_t;
}
//! Wrappers around the [cublas API](https://docs.nvidia.com/cuda/cublas/index.html),
//! in three levels. See crate documentation for description of each.
#[cfg(feature = "f16")]
pub mod half;
pub mod result;
pub mod safe;
#[allow(warnings)]
pub mod sys;
pub use safe::*;
use super::sys;
use core::ffi::{c_int, c_longlong, c_void};
use core::mem::MaybeUninit;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct CublasError(pub sys::cublasStatus_t);
impl sys::cublasStatus_t {
pub fn result(self) -> Result<(), CublasError> {
match self {
sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
_ => Err(CublasError(self)),
}
}
}
#[cfg(feature = "std")]
impl std::fmt::Display for CublasError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
#[cfg(feature = "std")]
impl std::error::Error for CublasError {}
/// Creates a handle to the cuBLAS library. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublascreate)
pub fn create_handle() -> Result<sys::cublasHandle_t, CublasError> {
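// cublasCreate_v2 writes the handle through an out-pointer, so we start
// with uninitialized memory and only call assume_init after it succeeds.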
let mut handle = MaybeUninit::uninit();
unsafe {
sys::cublasCreate_v2(handle.as_mut_ptr()).result()?;
Ok(handle.assume_init())
}
}
/// Destroys a handle previously created with [create_handle()]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasdestroy)
///
/// # Safety
///
/// `handle` must not have been freed already.
pub unsafe fn destroy_handle(handle: sys::cublasHandle_t) -> Result<(), CublasError> {
sys::cublasDestroy_v2(handle).result()
}
/// Sets the stream cuBLAS will use. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublassetstream)
///
/// # Safety
///
/// `handle` and `stream` must be valid.
pub unsafe fn set_stream(
handle: sys::cublasHandle_t,
stream: sys::cudaStream_t,
) -> Result<(), CublasError> {
sys::cublasSetStream_v2(handle, stream).result()
}
/// Single precision matrix vector multiplication. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemv)
///
/// # Safety
///
/// - `a`, `x`, and `y` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn sgemv(
handle: sys::cublasHandle_t,
trans: sys::cublasOperation_t,
m: c_int,
n: c_int,
alpha: *const f32,
a: *const f32,
lda: c_int,
x: *const f32,
incx: c_int,
beta: *const f32,
y: *mut f32,
incy: c_int,
) -> Result<(), CublasError> {
sys::cublasSgemv_v2(handle, trans, m, n, alpha, a, lda, x, incx, beta, y, incy).result()
}
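// ---------------------------------------------------------------------------
// A minimal sketch (not from this commit) of driving the raw wrappers above.
// Assumptions: `d_a`, `d_x` and `d_y` are valid f32 device pointers for a
// column-major m x n matrix and appropriately sized vectors, obtained
// elsewhere (e.g. through the driver API).
#[allow(dead_code)]
unsafe fn sgemv_once(
    d_a: *const f32,
    d_x: *const f32,
    d_y: *mut f32,
    m: c_int,
    n: c_int,
) -> Result<(), CublasError> {
    let handle = create_handle()?;
    let (alpha, beta) = (1.0f32, 0.0f32);
    // y = alpha * A * x + beta * y, with A column-major and lda = m
    sgemv(
        handle,
        sys::cublasOperation_t::CUBLAS_OP_N,
        m,
        n,
        &alpha,
        d_a,
        m,
        d_x,
        1,
        &beta,
        d_y,
        1,
    )?;
    destroy_handle(handle)
}
// ---------------------------------------------------------------------------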
/// Double precision matrix vector multiplication. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemv)
///
/// # Safety
///
/// - `a`, `x`, and `y` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn dgemv(
handle: sys::cublasHandle_t,
trans: sys::cublasOperation_t,
m: c_int,
n: c_int,
alpha: *const f64,
a: *const f64,
lda: c_int,
x: *const f64,
incx: c_int,
beta: *const f64,
y: *mut f64,
incy: c_int,
) -> Result<(), CublasError> {
sys::cublasDgemv_v2(handle, trans, m, n, alpha, a, lda, x, incx, beta, y, incy).result()
}
#[cfg(feature = "f16")]
/// Half precision matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn hgemm(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const half::f16,
a: *const half::f16,
lda: c_int,
b: *const half::f16,
ldb: c_int,
beta: *const half::f16,
c: *mut half::f16,
ldc: c_int,
) -> Result<(), CublasError> {
super::half::cublasHgemm(
handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
)
.result()
}
/// Single precision matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn sgemm(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f32,
a: *const f32,
lda: c_int,
b: *const f32,
ldb: c_int,
beta: *const f32,
c: *mut f32,
ldc: c_int,
) -> Result<(), CublasError> {
sys::cublasSgemm_v2(
handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
)
.result()
}
/// Double precision matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn dgemm(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f64,
a: *const f64,
lda: c_int,
b: *const f64,
ldb: c_int,
beta: *const f64,
c: *mut f64,
ldc: c_int,
) -> Result<(), CublasError> {
sys::cublasDgemm_v2(
handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
)
.result()
}
#[cfg(feature = "f16")]
/// Half precision batched matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmstridedbatched)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn hgemm_strided_batched(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const half::f16,
a: *const half::f16,
lda: c_int,
stride_a: c_longlong,
b: *const half::f16,
ldb: c_int,
stride_b: c_longlong,
beta: *const half::f16,
c: *mut half::f16,
ldc: c_int,
stride_c: c_longlong,
batch_size: c_int,
) -> Result<(), CublasError> {
super::half::cublasHgemmStridedBatched(
handle, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc,
stride_c, batch_size,
)
.result()
}
/// Single precision batched matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmstridedbatched)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn sgemm_strided_batched(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f32,
a: *const f32,
lda: c_int,
stride_a: c_longlong,
b: *const f32,
ldb: c_int,
stride_b: c_longlong,
beta: *const f32,
c: *mut f32,
ldc: c_int,
stride_c: c_longlong,
batch_size: c_int,
) -> Result<(), CublasError> {
sys::cublasSgemmStridedBatched(
handle, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc,
stride_c, batch_size,
)
.result()
}
/// Double precision batched matmul. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmstridedbatched)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn dgemm_strided_batched(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const f64,
a: *const f64,
lda: c_int,
stride_a: c_longlong,
b: *const f64,
ldb: c_int,
stride_b: c_longlong,
beta: *const f64,
c: *mut f64,
ldc: c_int,
stride_c: c_longlong,
batch_size: c_int,
) -> Result<(), CublasError> {
sys::cublasDgemmStridedBatched(
handle, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc,
stride_c, batch_size,
)
.result()
}
/// Matmul with data types specified as parameters. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemm_ex(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const c_void,
a: *const c_void,
a_type: sys::cudaDataType,
lda: c_int,
b: *const c_void,
b_type: sys::cudaDataType,
ldb: c_int,
beta: *const c_void,
c: *mut c_void,
c_type: sys::cudaDataType,
ldc: c_int,
compute_type: sys::cublasComputeType_t,
algo: sys::cublasGemmAlgo_t,
) -> Result<(), CublasError> {
sys::cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
a,
a_type,
lda,
b,
b_type,
ldb,
beta,
c,
c_type,
ldc,
compute_type,
algo,
)
.result()
}
/// Strided batched matmul with data types specified as parameters. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex)
///
/// # Safety
///
/// - `a`, `b`, and `c` must be valid device pointers that have not been freed.
/// - `alpha` and `beta` can be pointers to host memory, but must not be null
/// - the strides and sizes must be specified correctly
#[allow(clippy::too_many_arguments)]
pub unsafe fn gemm_strided_batched_ex(
handle: sys::cublasHandle_t,
transa: sys::cublasOperation_t,
transb: sys::cublasOperation_t,
m: c_int,
n: c_int,
k: c_int,
alpha: *const c_void,
a: *const c_void,
a_type: sys::cudaDataType,
lda: c_int,
stride_a: c_longlong,
b: *const c_void,
b_type: sys::cudaDataType,
ldb: c_int,
stride_b: c_longlong,
beta: *const c_void,
c: *mut c_void,
c_type: sys::cudaDataType,
ldc: c_int,
stride_c: c_longlong,
batch_count: c_int,
compute_type: sys::cublasComputeType_t,
algo: sys::cublasGemmAlgo_t,
) -> Result<(), CublasError> {
sys::cublasGemmStridedBatchedEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
a,
a_type,
lda,
stride_a,
b,
b_type,
ldb,
stride_b,
beta,
c,
c_type,
ldc,
stride_c,
batch_count,
compute_type,
algo,
)
.result()
}
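// wrapper.h (the header handed to bindgen by the script below)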
#include "cublas_v2.h"
#!/bin/bash
# Requires rust-bindgen 0.68.1 or newer
set -exu
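# Defining __CUDA_BF16_TYPES_EXIST__ makes the CUDA headers expose the
# bf16 (__nv_bfloat16) cublasLt declarations to bindgen.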
BINDGEN_EXTRA_CLANG_ARGS="-D__CUDA_BF16_TYPES_EXIST__" \
bindgen \
--allowlist-type="^cublasLt.*" \
--allowlist-var="^cublasLt.*" \
--allowlist-function="^cublasLt.*" \
--default-enum-style=rust \
--no-doc-comments \
--with-derive-default \
--with-derive-eq \
--with-derive-hash \
--with-derive-ord \
--use-core \
wrapper.h -- -I/usr/local/cuda/include \
> sys.rs
pub mod result;
pub mod safe;
#[allow(warnings)]
pub mod sys;
pub use safe::*;
use super::sys;
use crate::cublaslt::sys::cublasLtMatmulAlgo_t;
use core::ffi::c_void;
use core::mem::MaybeUninit;
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct CublasError(pub sys::cublasStatus_t);
impl sys::cublasStatus_t {
pub fn result(self) -> Result<(), CublasError> {
match self {
sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
_ => Err(CublasError(self)),
}
}
}
#[cfg(feature = "std")]
impl std::fmt::Display for CublasError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
#[cfg(feature = "std")]
impl std::error::Error for CublasError {}
/// Creates a handle to the cuBLASLt library. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltcreate)
pub fn create_handle() -> Result<sys::cublasLtHandle_t, CublasError> {
let mut handle = MaybeUninit::uninit();
unsafe {
sys::cublasLtCreate(handle.as_mut_ptr()).result()?;
Ok(handle.assume_init())
}
}
/// Destroys a handle previously created with [create_handle()]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltdestroy)
///
/// # Safety
///
/// `handle` must not have been freed already.
pub unsafe fn destroy_handle(handle: sys::cublasLtHandle_t) -> Result<(), CublasError> {
sys::cublasLtDestroy(handle).result()
}
/// Creates a matrix layout descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutcreate)
pub fn create_matrix_layout(
matrix_type: sys::cudaDataType,
rows: u64,
cols: u64,
ld: i64,
) -> Result<sys::cublasLtMatrixLayout_t, CublasError> {
let mut matrix_layout = MaybeUninit::uninit();
unsafe {
sys::cublasLtMatrixLayoutCreate(matrix_layout.as_mut_ptr(), matrix_type, rows, cols, ld)
.result()?;
Ok(matrix_layout.assume_init())
}
}
/// Sets the value of the specified attribute belonging to a previously created matrix layout
/// descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutsetattribute)
///
/// # Safety
/// `matrix_layout` must not have been freed already.
pub unsafe fn set_matrix_layout_attribute(
matrix_layout: sys::cublasLtMatrixLayout_t,
attr: sys::cublasLtMatrixLayoutAttribute_t,
buf: *const c_void,
buf_size: usize,
) -> Result<(), CublasError> {
sys::cublasLtMatrixLayoutSetAttribute(matrix_layout, attr, buf, buf_size).result()
}
/// Destroys a matrix layout previously created with [create_matrix_layout(...)]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatrixlayoutdestroy)
///
/// # Safety
///
/// `matrix_layout` must not have been freed already.
pub unsafe fn destroy_matrix_layout(
matrix_layout: sys::cublasLtMatrixLayout_t,
) -> Result<(), CublasError> {
sys::cublasLtMatrixLayoutDestroy(matrix_layout).result()
}
/// Creates a matrix multiply descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmuldesccreate)
pub fn create_matmul_desc(
compute_type: sys::cublasComputeType_t,
scale_type: sys::cudaDataType,
) -> Result<sys::cublasLtMatmulDesc_t, CublasError> {
let mut matmul_desc = MaybeUninit::uninit();
unsafe {
sys::cublasLtMatmulDescCreate(matmul_desc.as_mut_ptr(), compute_type, scale_type)
.result()?;
Ok(matmul_desc.assume_init())
}
}
/// Sets the value of the specified attribute belonging to a previously created matrix multiply
/// descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmuldescsetattribute)
///
/// # Safety
/// `matmul_desc` must not be freed already.
pub unsafe fn set_matmul_desc_attribute(
matmul_desc: sys::cublasLtMatmulDesc_t,
attr: sys::cublasLtMatmulDescAttributes_t,
buf: *const c_void,
buf_size: usize,
) -> Result<(), CublasError> {
sys::cublasLtMatmulDescSetAttribute(matmul_desc, attr, buf, buf_size).result()
}
/// Destroys a matrix multiply descriptor previously created with [create_matmul_desc(...)]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmuldescdestroy)
///
/// # Safety
///
/// `matmul_desc` must not have been freed already.
pub unsafe fn destroy_matmul_desc(
matmul_desc: sys::cublasLtMatmulDesc_t,
) -> Result<(), CublasError> {
sys::cublasLtMatmulDescDestroy(matmul_desc).result()
}
/// Creates a matrix multiply heuristic search preferences descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulpreferencecreate)
pub fn create_matmul_pref() -> Result<sys::cublasLtMatmulPreference_t, CublasError> {
let mut matmul_pref = MaybeUninit::uninit();
unsafe {
sys::cublasLtMatmulPreferenceCreate(matmul_pref.as_mut_ptr()).result()?;
Ok(matmul_pref.assume_init())
}
}
/// Sets the value of the specified attribute belonging to a previously created matrix multiply
/// preferences descriptor. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulpreferencesetattribute)
///
/// # Safety
/// `matmul_pref` must not have been freed already.
pub unsafe fn set_matmul_pref_attribute(
matmul_pref: sys::cublasLtMatmulPreference_t,
attr: sys::cublasLtMatmulPreferenceAttributes_t,
buf: *const c_void,
buf_size: usize,
) -> Result<(), CublasError> {
sys::cublasLtMatmulPreferenceSetAttribute(matmul_pref, attr, buf, buf_size).result()
}
/// Destroys a matrix multiply preferences descriptor previously created
/// with [create_matmul_pref()]. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulpreferencedestroy)
///
/// # Safety
///
/// `matmul_pref` must not have been freed already.
pub unsafe fn destroy_matmul_pref(
matmul_pref: sys::cublasLtMatmulPreference_t,
) -> Result<(), CublasError> {
sys::cublasLtMatmulPreferenceDestroy(matmul_pref).result()
}
/// Retrieves the fastest available algorithm for the matrix multiply operation
/// given the input matrices A, B and C and the output matrix D. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmulalgogetheuristic)
///
/// # Safety
/// None of the parameters may have been freed, and the layouts must describe valid allocations.
pub unsafe fn get_matmul_algo_heuristic(
handle: sys::cublasLtHandle_t,
matmul_desc: sys::cublasLtMatmulDesc_t,
a_layout: sys::cublasLtMatrixLayout_t,
b_layout: sys::cublasLtMatrixLayout_t,
c_layout: sys::cublasLtMatrixLayout_t,
d_layout: sys::cublasLtMatrixLayout_t,
matmul_pref: sys::cublasLtMatmulPreference_t,
) -> Result<sys::cublasLtMatmulHeuristicResult_t, CublasError> {
let mut matmul_heuristic = MaybeUninit::uninit();
let mut algo_count = 0;
sys::cublasLtMatmulAlgoGetHeuristic(
handle,
matmul_desc,
a_layout,
b_layout,
c_layout,
d_layout,
matmul_pref,
1, // only select the fastest algo
matmul_heuristic.as_mut_ptr(),
&mut algo_count,
)
.result()?;
if algo_count == 0 {
return Err(CublasError(
sys::cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED,
));
}
let matmul_heuristic = matmul_heuristic.assume_init();
matmul_heuristic.state.result()?;
Ok(matmul_heuristic)
}
/// Computes the matrix multiplication of matrices A and B to produce the output matrix D,
/// according to the following operation: D = alpha*(A*B) + beta*(C)
/// where A, B, and C are input matrices, and alpha and beta are input scalars. See
/// [nvidia docs](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul)
///
/// # Safety
/// None of the `sys` objects may have been freed already.
#[allow(clippy::too_many_arguments)]
pub unsafe fn matmul(
handle: sys::cublasLtHandle_t,
matmul_desc: sys::cublasLtMatmulDesc_t,
alpha: *const c_void,
beta: *const c_void,
a: *const c_void,
a_layout: sys::cublasLtMatrixLayout_t,
b: *const c_void,
b_layout: sys::cublasLtMatrixLayout_t,
c: *const c_void,
c_layout: sys::cublasLtMatrixLayout_t,
d: *mut c_void,
d_layout: sys::cublasLtMatrixLayout_t,
algo: *const cublasLtMatmulAlgo_t,
workspace: *mut c_void,
workspace_size: usize,
stream: sys::cudaStream_t,
) -> Result<(), CublasError> {
sys::cublasLtMatmul(
handle,
matmul_desc,
alpha,
a,
a_layout,
b,
b_layout,
beta,
c,
c_layout,
d,
d_layout,
algo,
workspace,
workspace_size,
stream,
)
.result()
}
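// ---------------------------------------------------------------------------
// A minimal sketch (not from this commit) of the full cuBLASLt call sequence
// exposed above, computing D = A * B for column-major f32 matrices.
// Assumptions: `d_a` (m x k), `d_b` (k x n) and `d_d` (m x n) are valid
// device pointers and `stream` is a valid CUDA stream.
#[allow(dead_code)]
unsafe fn lt_matmul_f32(
    d_a: *const c_void,
    d_b: *const c_void,
    d_d: *mut c_void,
    m: u64,
    n: u64,
    k: u64,
    stream: sys::cudaStream_t,
) -> Result<(), CublasError> {
    let handle = create_handle()?;
    let desc = create_matmul_desc(
        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
        sys::cudaDataType::CUDA_R_32F,
    )?;
    let a_layout = create_matrix_layout(sys::cudaDataType::CUDA_R_32F, m, k, m as i64)?;
    let b_layout = create_matrix_layout(sys::cudaDataType::CUDA_R_32F, k, n, k as i64)?;
    let d_layout = create_matrix_layout(sys::cudaDataType::CUDA_R_32F, m, n, m as i64)?;
    let pref = create_matmul_pref()?;
    // C and D share one layout here because beta = 0 leaves C unused.
    let heuristic =
        get_matmul_algo_heuristic(handle, desc, a_layout, b_layout, d_layout, d_layout, pref)?;
    let (alpha, beta) = (1.0f32, 0.0f32);
    matmul(
        handle,
        desc,
        &alpha as *const f32 as *const c_void,
        &beta as *const f32 as *const c_void,
        d_a,
        a_layout,
        d_b,
        b_layout,
        d_d as *const c_void,
        d_layout,
        d_d,
        d_layout,
        &heuristic.algo,
        core::ptr::null_mut(), // no workspace
        0,
        stream,
    )?;
    destroy_matmul_pref(pref)?;
    destroy_matrix_layout(d_layout)?;
    destroy_matrix_layout(b_layout)?;
    destroy_matrix_layout(a_layout)?;
    destroy_matmul_desc(desc)?;
    destroy_handle(handle)
}
// ---------------------------------------------------------------------------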