Commit 3ec8c534 authored by yongshk's avatar yongshk
Browse files

Initial commit

parents
//! A thin wrapper around [sys] providing [Result]s with [NvrtcError].
use super::sys;
use core::{
ffi::{c_char, c_int, CStr},
mem::MaybeUninit,
};
use std::{ffi::CString, vec::Vec};
/// Wrapper around [sys::nvrtcResult]. See
/// [nvrtcResult docs](https://docs.nvidia.com/cuda/nvrtc/index.html#group__error_1g31e41ef222c0ea75b4c48f715b3cd9f0)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NvrtcError(pub sys::nvrtcResult);
impl sys::nvrtcResult {
/// Transforms into a [Result] of [NvrtcError]
pub fn result(self) -> Result<(), NvrtcError> {
match self {
sys::nvrtcResult::NVRTC_SUCCESS => Ok(()),
_ => Err(NvrtcError(self)),
}
}
}
#[cfg(feature = "std")]
impl std::fmt::Display for NvrtcError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
#[cfg(feature = "std")]
impl std::error::Error for NvrtcError {}
/// Creates a program from source code `src`. This should be source code from a .cu file.
///
/// See [nvrtcCreateProgram() docs](https://docs.nvidia.com/cuda/nvrtc/index.html#group__compilation_1g9ae65f68911d1cf0adda2af4ad8cb458)
///
/// Example:
/// ```rust
/// # use cudarc::nvrtc::result::*;
/// let prog = create_program("extern \"C\" __global__ void kernel() { }").unwrap();
/// ```
pub fn create_program<S: AsRef<str>>(src: S) -> Result<sys::nvrtcProgram, NvrtcError> {
let src_c = CString::new(src.as_ref()).unwrap();
let mut prog = MaybeUninit::uninit();
unsafe {
sys::nvrtcCreateProgram(
prog.as_mut_ptr(),
src_c.as_c_str().as_ptr(),
std::ptr::null(),
0,
std::ptr::null(),
std::ptr::null(),
)
.result()?;
Ok(prog.assume_init())
}
}
/// Compiles an already created program. Options should be of the form specified
/// in [nvrtc's supported compiler options](https://docs.nvidia.com/cuda/nvrtc/index.html#group__options).
///
/// See [nvrtcCompileProgram() docs](https://docs.nvidia.com/cuda/nvrtc/index.html#group__compilation_1g1f3136029db1413e362154b567297e8b)
///
/// Example:
///
/// ```rust
/// # use cudarc::nvrtc::result::*;
/// let prog = create_program("extern \"C\" __global__ void kernel() { }").unwrap();
/// unsafe { compile_program(prog, &["--ftz=true", "--fmad=true"]) }.unwrap();
/// ```
///
/// # Safety
///
/// `prog` must be created from [create_program()] and not have been freed by [destroy_program()].
pub unsafe fn compile_program<O: Clone + Into<Vec<u8>>>(
prog: sys::nvrtcProgram,
options: &[O],
) -> Result<(), NvrtcError> {
let c_strings: Vec<CString> = options
.iter()
.cloned()
.map(|o| CString::new(o).unwrap())
.collect();
let c_strs: Vec<&CStr> = c_strings.iter().map(CString::as_c_str).collect();
let opts: Vec<*const c_char> = c_strs.iter().cloned().map(CStr::as_ptr).collect();
sys::nvrtcCompileProgram(prog, opts.len() as c_int, opts.as_ptr()).result()
}
/// Releases resources associated with `prog`.
///
/// See [nvrtcDestroyProgram() docs](https://docs.nvidia.com/cuda/nvrtc/index.html#group__compilation_1gaa237c59615b7d4f48d5b308b5c9b140).
///
/// # Safety
///
/// `prog` must be created from [create_program()] and not have been freed by [destroy_program()].
pub unsafe fn destroy_program(prog: sys::nvrtcProgram) -> Result<(), NvrtcError> {
sys::nvrtcDestroyProgram(&prog as *const _ as *mut _).result()
}
/// Extract the ptx associated with `prog`. Call [compile_program()] before this.
///
/// See [nvrtcGetPTX() docs](https://docs.nvidia.com/cuda/nvrtc/index.html#group__compilation_1gc9a66bbbd47c256f4a8955517b3965da)
/// and [nvrtcGetPTXSize() docs](https://docs.nvidia.com/cuda/nvrtc/index.html#group__compilation_1gc622d6ffb6fff71e209407da19612c1a).
///
/// # Safety
///
/// `prog` must be created from [create_program()] and not have been freed by [destroy_program()].
#[allow(clippy::slow_vector_initialization)]
pub unsafe fn get_ptx(prog: sys::nvrtcProgram) -> Result<Vec<c_char>, NvrtcError> {
let mut size: usize = 0;
sys::nvrtcGetPTXSize(prog, &mut size as *mut _).result()?;
let mut ptx_src: Vec<c_char> = Vec::with_capacity(size);
ptx_src.resize(size, 0);
sys::nvrtcGetPTX(prog, ptx_src.as_mut_ptr()).result()?;
Ok(ptx_src)
}
/// Extract log from a compiled program.
///
/// See [nvrtcGetProgramLog() docs](https://docs.nvidia.com/cuda/nvrtc/index.html#group__compilation_1g74c550e5cab81efbd59e4f72579edbd1)
/// and [nvrtcGetProgramLogSize() docs](https://docs.nvidia.com/cuda/nvrtc/index.html#group__compilation_1g59944bb118095ab53eec8994d056a18d).
///
/// # Safety
///
/// `prog` must be created from [create_program()] and not have been freed by [destroy_program()].
#[allow(clippy::slow_vector_initialization)]
pub unsafe fn get_program_log(prog: sys::nvrtcProgram) -> Result<Vec<c_char>, NvrtcError> {
let mut size: usize = 0;
sys::nvrtcGetProgramLogSize(prog, &mut size as *mut _).result()?;
let mut log_src: Vec<c_char> = Vec::with_capacity(size);
log_src.resize(size, 0);
sys::nvrtcGetProgramLog(prog, log_src.as_mut_ptr()).result()?;
Ok(log_src)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compile_program_no_opts() {
let prog = create_program("extern \"C\" __global__ void kernel() { }").unwrap();
unsafe { compile_program::<&str>(prog, &[]) }.unwrap();
unsafe { destroy_program(prog) }.unwrap();
}
#[test]
fn test_compile_program_1_opt() {
let prog = create_program("extern \"C\" __global__ void kernel() { }").unwrap();
unsafe { compile_program(prog, &["--ftz=true"]) }.unwrap();
unsafe { destroy_program(prog) }.unwrap();
}
#[test]
fn test_compile_program_2_opt() {
let prog = create_program("extern \"C\" __global__ void kernel() { }").unwrap();
unsafe { compile_program(prog, &["--ftz=true", "--fmad=true"]) }.unwrap();
unsafe { destroy_program(prog) }.unwrap();
}
#[test]
fn test_compile_bad_program() {
let prog = create_program("extern \"C\" __global__ void kernel(").unwrap();
assert_eq!(
unsafe { compile_program::<&str>(prog, &[]) }.unwrap_err(),
NvrtcError(sys::nvrtcResult::NVRTC_ERROR_COMPILATION)
);
}
#[test]
fn test_get_ptx() {
const SRC: &str =
"extern \"C\" __global__ void sin_kernel(float *out, const float *inp, int numel) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}";
let prog = create_program(SRC).unwrap();
unsafe { compile_program::<&str>(prog, &[]) }.unwrap();
let ptx = unsafe { get_ptx(prog) }.unwrap();
assert!(!ptx.is_empty());
let log = unsafe { get_program_log(prog) }.unwrap();
assert!(!log.is_empty());
unsafe { destroy_program(prog) }.unwrap();
}
}
//! Safe abstractions around [crate::nvrtc::result] for compiling PTX files.
//!
//! Call [compile_ptx()] or [compile_ptx_with_opts()].
use super::{result, sys};
use core::ffi::{c_char, CStr};
use std::ffi::CString;
use std::{borrow::ToOwned, path::PathBuf, string::String, vec::Vec};
/// An opaque structure representing a compiled PTX program
/// output from [compile_ptx()] or [compile_ptx_with_opts()].
///
/// Can also be created from a [Ptx::from_file] and [Ptx::from_src]
#[derive(Debug, Clone)]
pub struct Ptx(pub(crate) PtxKind);
impl Ptx {
/// Creates a Ptx from a pre-compiled .ptx file.
pub fn from_file<P: Into<PathBuf>>(path: P) -> Self {
Self(PtxKind::File(path.into()))
}
/// Creates a Ptx from the source string of a pre-compiled .ptx
/// file.
pub fn from_src<S: Into<String>>(src: S) -> Self {
Self(PtxKind::Src(src.into()))
}
/// Get the compiled source as a string.
pub fn to_src(&self) -> String {
match &self.0 {
PtxKind::Image(bytes) => unsafe { CStr::from_ptr(bytes.as_ptr()) }
.to_str()
.expect("Unable to convert bytes to str.")
.to_owned(),
PtxKind::Src(src) => src.clone(),
PtxKind::File(path) => {
std::fs::read_to_string(path).expect("Unable to read ptx from file.")
}
}
}
}
impl<S: Into<String>> From<S> for Ptx {
fn from(value: S) -> Self {
Self::from_src(value)
}
}
#[derive(Debug, Clone)]
pub(crate) enum PtxKind {
/// An image created by [compile_ptx]
Image(Vec<c_char>),
/// Content of a pre compiled ptx file
Src(String),
/// Path to a compiled ptx
File(PathBuf),
}
/// Calls [compile_ptx_with_opts] with no options. `src` is the source string
/// of a `.cu` file.
///
/// Example:
/// ```rust
/// # use cudarc::nvrtc::*;
/// let ptx = compile_ptx("extern \"C\" __global__ void kernel() { }").unwrap();
/// ```
pub fn compile_ptx<S: AsRef<str>>(src: S) -> Result<Ptx, CompileError> {
compile_ptx_with_opts(src, Default::default())
}
/// Compiles `src` with the given `opts`. `src` is the source string of a `.cu` file.
///
/// Example:
/// ```rust
/// # use cudarc::nvrtc::*;
/// let opts = CompileOptions {
/// ftz: Some(true),
/// maxrregcount: Some(10),
/// ..Default::default()
/// };
/// let ptx = compile_ptx_with_opts("extern \"C\" __global__ void kernel() { }", opts).unwrap();
/// ```
pub fn compile_ptx_with_opts<S: AsRef<str>>(
src: S,
opts: CompileOptions,
) -> Result<Ptx, CompileError> {
let prog = Program::create(src)?;
prog.compile(opts)
}
pub(crate) struct Program {
prog: sys::nvrtcProgram,
}
impl Program {
pub(crate) fn create<S: AsRef<str>>(src: S) -> Result<Self, CompileError> {
let prog = result::create_program(src).map_err(CompileError::CreationError)?;
Ok(Self { prog })
}
pub(crate) fn compile(self, opts: CompileOptions) -> Result<Ptx, CompileError> {
let options = opts.build();
unsafe { result::compile_program(self.prog, &options) }.map_err(|e| {
let log_raw = unsafe { result::get_program_log(self.prog) }.unwrap();
let log_ptr = log_raw.as_ptr();
let log = unsafe { CStr::from_ptr(log_ptr) }.to_owned();
CompileError::CompileError {
nvrtc: e,
options,
log,
}
})?;
let image = unsafe { result::get_ptx(self.prog) }.map_err(CompileError::GetPtxError)?;
Ok(Ptx(PtxKind::Image(image)))
}
}
impl Drop for Program {
fn drop(&mut self) {
let prog = std::mem::replace(&mut self.prog, std::ptr::null_mut());
if !prog.is_null() {
unsafe { result::destroy_program(prog) }.unwrap()
}
}
}
/// Represents an error that happens during nvrtc compilation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompileError {
/// Error happened during [result::create_program()]
CreationError(result::NvrtcError),
/// Error happened during [result::compile_program()]
CompileError {
nvrtc: result::NvrtcError,
options: Vec<String>,
log: CString,
},
/// Error happened during [result::get_program_log()]
GetLogError(result::NvrtcError),
/// Error happened during [result::get_ptx()]
GetPtxError(result::NvrtcError),
/// Error happened during [result::destroy_program()]
DestroyError(result::NvrtcError),
}
#[cfg(feature = "std")]
impl std::fmt::Display for CompileError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
#[cfg(feature = "std")]
impl std::error::Error for CompileError {}
/// Flags you can pass to the nvrtc compiler.
/// See <https://docs.nvidia.com/cuda/nvrtc/index.html#group__options>
/// for all available flags and documentation for what they do.
///
/// All fields of this struct match one of the flags in the documentation.
/// if a field is `None` it will not be passed to the compiler.
///
/// All fields default to `None`.
///
/// *NOTE*: not all flags are currently supported.
///
/// Example:
/// ```rust
/// # use cudarc::nvrtc::*;
/// // "--ftz=true" will be passed to the compiler
/// let opts = CompileOptions {
/// ftz: Some(true),
/// ..Default::default()
/// };
/// ```
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)]
pub struct CompileOptions {
pub ftz: Option<bool>,
pub prec_sqrt: Option<bool>,
pub prec_div: Option<bool>,
pub fmad: Option<bool>,
pub use_fast_math: Option<bool>,
pub maxrregcount: Option<usize>,
pub include_paths: Vec<String>,
pub arch: Option<&'static str>,
}
impl CompileOptions {
pub(crate) fn build(self) -> Vec<String> {
let mut options: Vec<String> = Vec::new();
if let Some(v) = self.ftz {
options.push(std::format!("--ftz={v}"));
}
if let Some(v) = self.prec_sqrt {
options.push(std::format!("--prec-sqrt={v}"));
}
if let Some(v) = self.prec_div {
options.push(std::format!("--prec-div={v}"));
}
if let Some(v) = self.fmad {
options.push(std::format!("--fmad={v}"));
}
if let Some(true) = self.use_fast_math {
options.push("--fmad=true".into());
}
if let Some(count) = self.maxrregcount {
options.push(std::format!("--maxrregcount={count}"));
}
for path in self.include_paths {
options.push(std::format!("--include-path={path}"));
}
if let Some(arch) = self.arch {
options.push(std::format!("--gpu-architecture={arch}"))
}
options
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compile_no_opts() {
const SRC: &str =
"extern \"C\" __global__ void sin_kernel(float *out, const float *inp, int numel) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}";
compile_ptx_with_opts(SRC, Default::default()).unwrap();
}
#[test]
fn test_compile_options_build_none() {
let opts: CompileOptions = Default::default();
assert!(opts.build().is_empty());
}
#[test]
fn test_compile_options_build_ftz() {
let opts = CompileOptions {
ftz: Some(true),
..Default::default()
};
assert_eq!(&opts.build(), &["--ftz=true"]);
}
#[test]
fn test_compile_options_build_multi() {
let opts = CompileOptions {
prec_div: Some(false),
maxrregcount: Some(60),
..Default::default()
};
assert_eq!(&opts.build(), &["--prec-div=false", "--maxrregcount=60"]);
}
}
//! Bindings from "nvrtc.h" generated by rust-bindgen 0.60.1
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub enum nvrtcResult {
NVRTC_SUCCESS = 0,
NVRTC_ERROR_OUT_OF_MEMORY = 1,
NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2,
NVRTC_ERROR_INVALID_INPUT = 3,
NVRTC_ERROR_INVALID_PROGRAM = 4,
NVRTC_ERROR_INVALID_OPTION = 5,
NVRTC_ERROR_COMPILATION = 6,
NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7,
NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8,
NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9,
NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10,
NVRTC_ERROR_INTERNAL_ERROR = 11,
}
extern "C" {
pub fn nvrtcGetErrorString(result: nvrtcResult) -> *const core::ffi::c_char;
}
extern "C" {
pub fn nvrtcVersion(major: *mut core::ffi::c_int, minor: *mut core::ffi::c_int) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetNumSupportedArchs(numArchs: *mut core::ffi::c_int) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetSupportedArchs(supportedArchs: *mut core::ffi::c_int) -> nvrtcResult;
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct _nvrtcProgram {
_unused: [u8; 0],
}
pub type nvrtcProgram = *mut _nvrtcProgram;
extern "C" {
pub fn nvrtcCreateProgram(
prog: *mut nvrtcProgram,
src: *const core::ffi::c_char,
name: *const core::ffi::c_char,
numHeaders: core::ffi::c_int,
headers: *const *const core::ffi::c_char,
includeNames: *const *const core::ffi::c_char,
) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcDestroyProgram(prog: *mut nvrtcProgram) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcCompileProgram(
prog: nvrtcProgram,
numOptions: core::ffi::c_int,
options: *const *const core::ffi::c_char,
) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetPTXSize(prog: nvrtcProgram, ptxSizeRet: *mut usize) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetPTX(prog: nvrtcProgram, ptx: *mut core::ffi::c_char) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetCUBINSize(prog: nvrtcProgram, cubinSizeRet: *mut usize) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetCUBIN(prog: nvrtcProgram, cubin: *mut core::ffi::c_char) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetProgramLogSize(prog: nvrtcProgram, logSizeRet: *mut usize) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetProgramLog(prog: nvrtcProgram, log: *mut core::ffi::c_char) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcAddNameExpression(
prog: nvrtcProgram,
name_expression: *const core::ffi::c_char,
) -> nvrtcResult;
}
extern "C" {
pub fn nvrtcGetLoweredName(
prog: nvrtcProgram,
name_expression: *const core::ffi::c_char,
lowered_name: *mut *const core::ffi::c_char,
) -> nvrtcResult;
}
#include "nvrtc.h"
\ No newline at end of file
//! Exposes [CudaTypeName] which maps between rust type names
//! and the corresponding cuda kernel type names.
//!
//! For example, `f32` in rust corresponds to `float` in a cuda
//! kernel.
/// Maps a rust type to it's corresponding [CudaTypeName::NAME] in cuda c++ land.
pub trait CudaTypeName {
const NAME: &'static str;
}
macro_rules! cuda_type {
($RustTy:ty, $CudaTy:expr) => {
impl CudaTypeName for $RustTy {
const NAME: &'static str = $CudaTy;
}
};
}
cuda_type!(bool, "bool");
cuda_type!(i8, "char");
cuda_type!(i16, "short");
cuda_type!(i32, "int");
cuda_type!(i64, "long");
cuda_type!(isize, "intptr_t");
cuda_type!(u8, "unsigned char");
cuda_type!(u16, "unsigned short");
cuda_type!(u32, "unsigned int");
cuda_type!(u64, "unsigned long");
cuda_type!(usize, "size_t");
cuda_type!(f32, "float");
cuda_type!(f64, "double");
#[cfg(feature = "f16")]
cuda_type!(half::f16, "__half");
#[cfg(feature = "f16")]
cuda_type!(half::bf16, "__nv_bfloat16");
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment