gpu.rs 778 Bytes
Newer Older
1
/// Queries the GPU compute capability through PyTorch.
///
/// Acquires the Python GIL, imports `torch.cuda`, and calls its
/// `get_device_capability` attribute with no arguments. The result is extracted
/// as a `(major, minor)` pair of signed integers.
///
/// # Returns
///
/// * `Some((major, minor))` when the Python call succeeds and both components
///   are non-negative.
/// * `None` when either component is negative. A `tracing` warning is logged.
/// * `None` when importing, attribute lookup, calling, or extraction fails.
///   A `tracing` warning containing the Python error is logged.
///
/// Errors are never propagated to the caller. Presumably this is so a missing
/// or CPU-only PyTorch install degrades gracefully. NOTE(review): confirm with
/// callers.
pub fn get_cuda_capability() -> Option<(usize, usize)> {
    use pyo3::prelude::*;

    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
        let torch = py.import_bound("torch.cuda")?;
        let get_device_capability = torch.getattr("get_device_capability")?;
        get_device_capability.call0()?.extract()
    };

    match pyo3::Python::with_gil(py_get_capability) {
        // `usize::try_from(isize)` fails exactly when the value is negative, so this
        // replaces a separate sign check followed by unchecked `as` casts.
        Ok((major, minor)) => match (usize::try_from(major), usize::try_from(minor)) {
            (Ok(major_u), Ok(minor_u)) => Some((major_u, minor_u)),
            _ => {
                tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
                None
            }
        },
        Err(err) => {
            tracing::warn!("Cannot determine GPU compute capability: {}", err);
            None
        }
    }
}