Commit 25d2752f authored by yongshk's avatar yongshk
Browse files

Initial commit

parents
use crate::backend::BackendDevice;
use crate::cpu_backend::CpuDevice;
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// can live on the same location (typically for cuda devices).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal { gpu_id: usize },
}
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
pub trait NdArray {
fn shape(&self) -> Result<Shape>;
fn to_cpu_storage(&self) -> CpuStorage;
}
impl<S: WithDType> NdArray for S {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(&[*self])
}
}
impl<S: WithDType, const N: usize> NdArray for &[S; N] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self.as_slice())
}
}
impl<S: WithDType> NdArray for &[S] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self)
}
}
impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((M, N)))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage_owned(self.concat())
}
}
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
for &[[[S; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::with_capacity(N1 * N2 * N3);
for i1 in 0..N1 {
for i2 in 0..N2 {
vec.extend(self[i1][i2])
}
}
S::to_cpu_storage_owned(vec)
}
}
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
for &[[[[S; N4]; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3, N4)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
for i1 in 0..N1 {
for i2 in 0..N2 {
for i3 in 0..N3 {
vec.extend(self[i1][i2][i3])
}
}
}
S::to_cpu_storage_owned(vec)
}
}
impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
}
}
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}
pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
pub fn location(&self) -> DeviceLocation {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}
pub fn is_cpu(&self) -> bool {
matches!(self, Self::Cpu)
}
pub fn is_cuda(&self) -> bool {
matches!(self, Self::Cuda(_))
}
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
if crate::utils::cuda_is_available() {
Self::new_cuda(ordinal)
} else {
Ok(Self::Cpu)
}
}
pub(crate) fn rand_uniform_f64(
&self,
lo: f64,
up: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
lo: T,
up: T,
shape: &Shape,
) -> Result<Storage> {
self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
}
pub(crate) fn rand_normal_f64(
&self,
mean: f64,
std: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn rand_normal<T: crate::FloatDType>(
&self,
mean: T,
std: T,
shape: &Shape,
) -> Result<Storage> {
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.ones_impl(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.zeros_impl(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Metal(storage))
}
}
}
}
/// Pretty printing of tensors
/// This implementation should be in line with the PyTorch version.
/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};
impl Tensor {
fn fmt_dt<T: WithDType + std::fmt::Display>(
&self,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(f, "Tensor[")?;
match self.dims() {
[] => {
if let Ok(v) = self.to_scalar::<T>() {
write!(f, "{v}")?
}
}
[s] if *s < 10 => {
if let Ok(vs) = self.to_vec1::<T>() {
for (i, v) in vs.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{v}")?;
}
}
}
dims => {
write!(f, "dims ")?;
for (i, d) in dims.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{d}")?;
}
}
}
write!(f, "; {}{}]", self.dtype().as_str(), device_str)
}
}
impl std::fmt::Debug for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
DType::F16 => self.fmt_dt::<f16>(f),
DType::F32 => self.fmt_dt::<f32>(f),
DType::F64 => self.fmt_dt::<f64>(f),
}
}
}
/// Options for Tensor pretty printing
#[derive(Debug, Clone)]
pub struct PrinterOptions {
pub precision: usize,
pub threshold: usize,
pub edge_items: usize,
pub line_width: usize,
pub sci_mode: Option<bool>,
}
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
std::sync::Mutex::new(PrinterOptions::const_default());
impl PrinterOptions {
// We cannot use the default trait as it's not const.
const fn const_default() -> Self {
Self {
precision: 4,
threshold: 1000,
edge_items: 3,
line_width: 80,
sci_mode: None,
}
}
}
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
&PRINT_OPTS
}
pub fn set_print_options(options: PrinterOptions) {
*PRINT_OPTS.lock().unwrap() = options
}
pub fn set_print_options_default() {
*PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default()
}
pub fn set_print_options_short() {
*PRINT_OPTS.lock().unwrap() = PrinterOptions {
precision: 2,
threshold: 1000,
edge_items: 2,
line_width: 80,
sci_mode: None,
}
}
pub fn set_print_options_full() {
*PRINT_OPTS.lock().unwrap() = PrinterOptions {
precision: 4,
threshold: usize::MAX,
edge_items: 3,
line_width: 80,
sci_mode: None,
}
}
pub fn set_line_width(line_width: usize) {
PRINT_OPTS.lock().unwrap().line_width = line_width
}
pub fn set_precision(precision: usize) {
PRINT_OPTS.lock().unwrap().precision = precision
}
pub fn set_edge_items(edge_items: usize) {
PRINT_OPTS.lock().unwrap().edge_items = edge_items
}
pub fn set_threshold(threshold: usize) {
PRINT_OPTS.lock().unwrap().threshold = threshold
}
pub fn set_sci_mode(sci_mode: Option<bool>) {
PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
}
struct FmtSize {
current_size: usize,
}
impl FmtSize {
fn new() -> Self {
Self { current_size: 0 }
}
fn final_size(self) -> usize {
self.current_size
}
}
impl std::fmt::Write for FmtSize {
fn write_str(&mut self, s: &str) -> std::fmt::Result {
self.current_size += s.len();
Ok(())
}
}
trait TensorFormatter {
type Elem: WithDType;
fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result;
fn max_width(&self, to_display: &Tensor) -> usize {
let mut max_width = 1;
if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) {
for &v in vs.iter() {
let mut fmt_size = FmtSize::new();
let _res = self.fmt(v, 1, &mut fmt_size);
max_width = usize::max(max_width, fmt_size.final_size())
}
}
max_width
}
fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
writeln!(f)?;
for _ in 0..i {
write!(f, " ")?
}
Ok(())
}
fn fmt_tensor(
&self,
t: &Tensor,
indent: usize,
max_w: usize,
summarize: bool,
po: &PrinterOptions,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
let dims = t.dims();
let edge_items = po.edge_items;
write!(f, "[")?;
match dims {
[] => {
if let Ok(v) = t.to_scalar::<Self::Elem>() {
self.fmt(v, max_w, f)?
}
}
[v] if summarize && *v > 2 * edge_items => {
if let Ok(vs) = t
.narrow(0, 0, edge_items)
.and_then(|t| t.to_vec1::<Self::Elem>())
{
for v in vs.into_iter() {
self.fmt(v, max_w, f)?;
write!(f, ", ")?;
}
}
write!(f, "...")?;
if let Ok(vs) = t
.narrow(0, v - edge_items, edge_items)
.and_then(|t| t.to_vec1::<Self::Elem>())
{
for v in vs.into_iter() {
write!(f, ", ")?;
self.fmt(v, max_w, f)?;
}
}
}
[_] => {
let elements_per_line = usize::max(1, po.line_width / (max_w + 2));
if let Ok(vs) = t.to_vec1::<Self::Elem>() {
for (i, v) in vs.into_iter().enumerate() {
if i > 0 {
if i % elements_per_line == 0 {
write!(f, ",")?;
Self::write_newline_indent(indent, f)?
} else {
write!(f, ", ")?;
}
}
self.fmt(v, max_w, f)?
}
}
}
_ => {
if summarize && dims[0] > 2 * edge_items {
for i in 0..edge_items {
match t.get(i) {
Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
Err(e) => write!(f, "{e:?}")?,
}
write!(f, ",")?;
Self::write_newline_indent(indent, f)?
}
write!(f, "...")?;
Self::write_newline_indent(indent, f)?;
for i in dims[0] - edge_items..dims[0] {
match t.get(i) {
Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
Err(e) => write!(f, "{e:?}")?,
}
if i + 1 != dims[0] {
write!(f, ",")?;
Self::write_newline_indent(indent, f)?
}
}
} else {
for i in 0..dims[0] {
match t.get(i) {
Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?,
Err(e) => write!(f, "{e:?}")?,
}
if i + 1 != dims[0] {
write!(f, ",")?;
Self::write_newline_indent(indent, f)?
}
}
}
}
}
write!(f, "]")?;
Ok(())
}
}
struct FloatFormatter<S: WithDType> {
int_mode: bool,
sci_mode: bool,
precision: usize,
_phantom: std::marker::PhantomData<S>,
}
impl<S> FloatFormatter<S>
where
S: WithDType + num_traits::Float + std::fmt::Display,
{
fn new(t: &Tensor, po: &PrinterOptions) -> Result<Self> {
let mut int_mode = true;
let mut sci_mode = false;
// Rather than containing all values, this should only include
// values that end up being displayed according to [threshold].
let values = t
.flatten_all()?
.to_vec1()?
.into_iter()
.filter(|v: &S| v.is_finite() && !v.is_zero())
.collect::<Vec<_>>();
if !values.is_empty() {
let mut nonzero_finite_min = S::max_value();
let mut nonzero_finite_max = S::min_value();
for &v in values.iter() {
let v = v.abs();
if v < nonzero_finite_min {
nonzero_finite_min = v
}
if v > nonzero_finite_max {
nonzero_finite_max = v
}
}
for &value in values.iter() {
if value.ceil() != value {
int_mode = false;
break;
}
}
if let Some(v1) = S::from(1000.) {
if let Some(v2) = S::from(1e8) {
if let Some(v3) = S::from(1e-4) {
sci_mode = nonzero_finite_max / nonzero_finite_min > v1
|| nonzero_finite_max > v2
|| nonzero_finite_min < v3
}
}
}
}
match po.sci_mode {
None => {}
Some(v) => sci_mode = v,
}
Ok(Self {
int_mode,
sci_mode,
precision: po.precision,
_phantom: std::marker::PhantomData,
})
}
}
impl<S> TensorFormatter for FloatFormatter<S>
where
S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp,
{
type Elem = S;
fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
if self.sci_mode {
write!(
f,
"{v:width$.prec$e}",
v = v,
width = max_w,
prec = self.precision
)
} else if self.int_mode {
if v.is_finite() {
write!(f, "{v:width$.0}.", v = v, width = max_w - 1)
} else {
write!(f, "{v:max_w$.0}")
}
} else {
write!(
f,
"{v:width$.prec$}",
v = v,
width = max_w,
prec = self.precision
)
}
}
}
struct IntFormatter<S: WithDType> {
_phantom: std::marker::PhantomData<S>,
}
impl<S: WithDType> IntFormatter<S> {
fn new() -> Self {
Self {
_phantom: std::marker::PhantomData,
}
}
}
impl<S> TensorFormatter for IntFormatter<S>
where
S: WithDType + std::fmt::Display,
{
type Elem = S;
fn fmt<T: std::fmt::Write>(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result {
write!(f, "{v:max_w$}")
}
}
fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result<Tensor> {
let dims = t.dims();
if dims.is_empty() {
Ok(t.clone())
} else if dims.len() == 1 {
if dims[0] > 2 * edge_items {
Tensor::cat(
&[
t.narrow(0, 0, edge_items)?,
t.narrow(0, dims[0] - edge_items, edge_items)?,
],
0,
)
} else {
Ok(t.clone())
}
} else if dims[0] > 2 * edge_items {
let mut vs: Vec<_> = (0..edge_items)
.map(|i| get_summarized_data(&t.get(i)?, edge_items))
.collect::<Result<Vec<_>>>()?;
for i in (dims[0] - edge_items)..dims[0] {
vs.push(get_summarized_data(&t.get(i)?, edge_items)?)
}
Tensor::cat(&vs, 0)
} else {
let vs: Vec<_> = (0..dims[0])
.map(|i| get_summarized_data(&t.get(i)?, edge_items))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&vs, 0)
}
}
impl std::fmt::Display for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let po = PRINT_OPTS.lock().unwrap();
let summarize = self.elem_count() > po.threshold;
let to_display = if summarize {
match get_summarized_data(self, po.edge_items) {
Ok(v) => v,
Err(err) => return write!(f, "{err:?}"),
}
} else {
self.clone()
};
match self.dtype() {
DType::U8 => {
let tf: IntFormatter<u8> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::U32 => {
let tf: IntFormatter<u32> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I64 => {
let tf: IntFormatter<i64> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::BF16 => {
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
}
DType::F16 => {
if let Ok(tf) = FloatFormatter::<f16>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
}
DType::F64 => {
if let Ok(tf) = FloatFormatter::<f64>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
}
DType::F32 => {
if let Ok(tf) = FloatFormatter::<f32>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
}
};
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(
f,
"Tensor[{:?}, {}{}]",
self.dims(),
self.dtype().as_str(),
device_str
)
}
}
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
// Unsigned 8 bits integer.
U8,
// Unsigned 32 bits integer.
U32,
// Signed 64 bits integer.
I64,
// Brain floating-point using half precision (16 bits).
BF16,
// Floating-point using half precision (16 bits).
F16,
// Floating-point using single precision (32 bits).
F32,
// Floating-point using double precision (64 bits).
F64,
}
#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError(String);
impl std::fmt::Display for DTypeParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "cannot parse '{}' as a dtype", self.0)
}
}
impl std::error::Error for DTypeParseError {}
impl std::str::FromStr for DType {
type Err = DTypeParseError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"u8" => Ok(Self::U8),
"u32" => Ok(Self::U32),
"i64" => Ok(Self::I64),
"bf16" => Ok(Self::BF16),
"f16" => Ok(Self::F16),
"f32" => Ok(Self::F32),
"f64" => Ok(Self::F64),
_ => Err(DTypeParseError(s.to_string())),
}
}
}
impl DType {
/// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
Self::U32 => "u32",
Self::I64 => "i64",
Self::BF16 => "bf16",
Self::F16 => "f16",
Self::F32 => "f32",
Self::F64 => "f64",
}
}
/// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
Self::U32 => 4,
Self::I64 => 8,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
Self::F64 => 8,
}
}
pub fn is_int(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => true,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
}
}
pub fn is_float(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => false,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
}
}
}
pub trait WithDType:
Sized
+ Copy
+ num_traits::NumAssign
+ std::cmp::PartialOrd
+ std::fmt::Display
+ 'static
+ Send
+ Sync
+ crate::cpu::kernels::VecOps
{
const DTYPE: DType;
fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
Self::to_cpu_storage_owned(data.to_vec())
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
}
macro_rules! with_dtype {
($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => {
impl WithDType for $ty {
const DTYPE: DType = DType::$dtype;
fn from_f64(v: f64) -> Self {
$from_f64(v)
}
fn to_f64(self) -> f64 {
$to_f64(self)
}
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
CpuStorage::$dtype(data)
}
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
match s {
CpuStorage::$dtype(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}
.bt()),
}
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
match s {
CpuStorage::$dtype(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}
.bt()),
}
}
}
};
}
use half::{bf16, f16};
with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
pub trait IntDType: WithDType {
fn is_true(&self) -> bool;
fn as_usize(&self) -> usize;
}
impl IntDType for i64 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
impl IntDType for u32 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
impl IntDType for u8 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
pub trait FloatDType: WithDType {}
impl FloatDType for f16 {}
impl FloatDType for bf16 {}
impl FloatDType for f32 {}
impl FloatDType for f64 {}
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
pub struct CudaDevice;
#[derive(Debug)]
pub struct CudaStorage;
macro_rules! fail {
() => {
unimplemented!("cuda support has not been enabled, add `cuda` feature to enable.")
};
}
impl crate::backend::BackendStorage for CudaStorage {
type Device = CudaDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn dtype(&self) -> DType {
fail!()
}
fn device(&self) -> &Self::Device {
fail!()
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
}
impl crate::backend::BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
fn same_device(&self, _: &Self) -> bool {
fail!()
}
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
}
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}
impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn dtype(&self) -> DType {
fail!()
}
fn device(&self) -> &Self::Device {
fail!()
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}
impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
fn same_device(&self, _: &Self) -> bool {
fail!()
}
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
}
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
pub lhs_l: Layout,
pub rhs_l: Layout,
pub bmnk: (usize, usize, usize, usize),
pub msg: &'static str,
}
/// Main library error type.
#[derive(thiserror::Error, Debug)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
DTypeMismatchBinaryOp {
lhs: DType,
rhs: DType,
op: &'static str,
},
#[error("unsupported dtype {0:?} for op {1}")]
UnsupportedDTypeForOp(DType, &'static str),
// === Dimension Index Errors ===
#[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
op: &'static str,
},
#[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
DuplicateDimIndex {
shape: Shape,
dims: Vec<usize>,
op: &'static str,
},
// === Shape Errors ===
#[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
UnexpectedNumberOfDims {
expected: usize,
got: usize,
shape: Shape,
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedShape {
msg: String,
expected: Shape,
got: Shape,
},
#[error(
"Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
)]
ShapeMismatch { buffer_size: usize, shape: Shape },
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp {
lhs: Shape,
rhs: Shape,
op: &'static str,
},
#[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
ShapeMismatchCat {
dim: usize,
first_shape: Shape,
n: usize,
nth_shape: Shape,
},
#[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
ShapeMismatchSplit {
shape: Shape,
dim: usize,
n_parts: usize,
},
#[error("{op} can only be performed on a single dimension")]
OnlySingleDimension { op: &'static str, dims: Vec<usize> },
#[error("empty tensor for {op}")]
EmptyTensor { op: &'static str },
// === Device Errors ===
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
DeviceMismatchBinaryOp {
lhs: DeviceLocation,
rhs: DeviceLocation,
op: &'static str,
},
// === Op Specific Errors ===
#[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
NarrowInvalidArgs {
shape: Shape,
dim: usize,
start: usize,
len: usize,
msg: &'static str,
},
#[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
Conv1dInvalidArgs {
inp_shape: Shape,
k_shape: Shape,
padding: usize,
stride: usize,
msg: &'static str,
},
#[error("{op} invalid index {index} with dim size {size}")]
InvalidIndex {
op: &'static str,
index: usize,
size: usize,
},
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
#[error("cannot set variable {msg}")]
CannotSetVar { msg: &'static str },
// Box indirection to avoid large variant.
#[error("{0:?}")]
MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
#[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str },
#[error("{op} expects at least two tensors")]
OpRequiresAtLeastTwoTensors { op: &'static str },
#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },
// === Other Errors ===
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
#[error("the candle crate has not been built with metal support")]
NotCompiledWithMetalSupport,
#[error("cannot find tensor {path}")]
CannotFindTensor { path: String },
// === Wrapped Errors ===
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
#[error("npy/npz error {0}")]
Npy(String),
/// Zip file format error.
#[error(transparent)]
Zip(#[from] zip::result::ZipError),
/// Integer parse error.
#[error(transparent)]
ParseInt(#[from] std::num::ParseIntError),
/// I/O error.
#[error(transparent)]
Io(#[from] std::io::Error),
/// SafeTensor error.
#[error(transparent)]
SafeTensor(#[from] safetensors::SafeTensorError),
#[error("unsupported safetensor dtype {0:?}")]
UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
WithPath {
inner: Box<Self>,
path: std::path::PathBuf,
},
#[error("{inner}\n{backtrace}")]
WithBacktrace {
inner: Box<Self>,
backtrace: Box<std::backtrace::Backtrace>,
},
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Msg(err.to_string()).bt()
}
pub fn bt(self) -> Self {
let backtrace = std::backtrace::Backtrace::capture();
match backtrace.status() {
std::backtrace::BacktraceStatus::Disabled
| std::backtrace::BacktraceStatus::Unsupported => self,
_ => Self::WithBacktrace {
inner: Box::new(self),
backtrace: Box::new(backtrace),
},
}
}
pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
Self::WithPath {
inner: Box::new(self),
path: p.as_ref().to_path_buf(),
}
}
}
#[macro_export]
macro_rules! bail {
($msg:literal $(,)?) => {
return Err($crate::Error::Msg(format!($msg).into()).bt())
};
($err:expr $(,)?) => {
return Err($crate::Error::Msg(format!($err).into()).bt())
};
($fmt:expr, $($arg:tt)*) => {
return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
};
}
pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
match (r1, r2) {
(Ok(r1), Ok(r2)) => Ok((r1, r2)),
(Err(e), _) => Err(e),
(_, Err(e)) => Err(e),
}
}
use crate::{Error, Tensor};
use std::ops::{
Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
impl Tensor {
/// Intended to be use by the trait `.i()`
///
/// ```
/// # use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = a.i(0..1)?;
/// assert_eq!(c.shape().dims(), &[1, 3]);
///
/// let c = a.i(0)?;
/// assert_eq!(c.shape().dims(), &[3]);
///
/// let c = a.i((.., ..2) )?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
///
/// let c = a.i((.., ..=2))?;
/// assert_eq!(c.shape().dims(), &[2, 3]);
///
/// # Ok::<(), candle_core::Error>(())
/// ```
fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
let mut x = self.clone();
let dims = self.shape().dims();
let mut current_dim = 0;
for (i, indexer) in indexers.iter().enumerate() {
x = match indexer {
TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
TensorIndexer::Narrow(left_bound, right_bound) => {
let start = match left_bound {
Bound::Included(n) => *n,
Bound::Excluded(n) => *n + 1,
Bound::Unbounded => 0,
};
let stop = match right_bound {
Bound::Included(n) => *n + 1,
Bound::Excluded(n) => *n,
Bound::Unbounded => dims[i],
};
let out = x.narrow(current_dim, start, stop.saturating_sub(start))?;
current_dim += 1;
out
}
TensorIndexer::IndexSelect(indexes) => {
if indexes.rank() != 1 {
crate::bail!("multi-dimensional tensor indexing is not supported")
}
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
current_dim += 1;
out
}
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elements for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
/// Indexing via a 1d tensor
IndexSelect(Tensor),
Err(Error),
}
impl From<usize> for TensorIndexer {
fn from(index: usize) -> Self {
TensorIndexer::Select(index)
}
}
impl From<&[u32]> for TensorIndexer {
fn from(index: &[u32]) -> Self {
match Tensor::new(index, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<Vec<u32>> for TensorIndexer {
fn from(index: Vec<u32>) -> Self {
let len = index.len();
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}
trait RB: RangeBounds<usize> {}
impl RB for Range<usize> {}
impl RB for RangeFrom<usize> {}
impl RB for RangeFull {}
impl RB for RangeInclusive<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}
impl<T: RB> From<T> for TensorIndexer {
fn from(range: T) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor
pub trait IndexOp<T> {
/// Returns a slicing iterator which are the chunks of data necessary to
/// reconstruct the desired tensor.
fn i(&self, index: T) -> Result<Tensor, Error>;
}
impl<T> IndexOp<T> for Tensor
where
T: Into<TensorIndexer>,
{
fn i(&self, index: T) -> Result<Tensor, Error> {
self.index(&[index.into()])
}
}
macro_rules! index_op_tuple {
($($t:ident),+) => {
#[allow(non_snake_case)]
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
where
$($t: Into<TensorIndexer>,)*
{
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
self.index(&[$($t.into(),)*])
}
}
};
}
index_op_tuple!(A);
index_op_tuple!(A, B);
index_op_tuple!(A, B, C);
index_op_tuple!(A, B, C, D);
index_op_tuple!(A, B, C, D, E);
index_op_tuple!(A, B, C, D, E, F);
index_op_tuple!(A, B, C, D, E, F, G);
use crate::{Error, Result, Shape};
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Layout {
shape: Shape,
// The strides are given in number of elements and not in bytes.
stride: Vec<usize>,
start_offset: usize,
}
impl Layout {
pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
Self {
shape,
stride,
start_offset,
}
}
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
let shape = shape.into();
let stride = shape.stride_contiguous();
Self {
shape,
stride,
start_offset,
}
}
pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
Self::contiguous_with_offset(shape, 0)
}
pub fn dims(&self) -> &[usize] {
self.shape.dims()
}
pub fn shape(&self) -> &Shape {
&self.shape
}
pub fn stride(&self) -> &[usize] {
&self.stride
}
pub fn start_offset(&self) -> usize {
self.start_offset
}
/// Returns the appropriate start and stop offset if the data is stored in a C
/// contiguous (aka row major) way.
pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
if self.is_contiguous() {
let start_o = self.start_offset;
Some((start_o, start_o + self.shape.elem_count()))
} else {
None
}
}
/// Returns true if the data is stored in a C contiguous (aka row major) way.
/// Note that this does not implies that the start offset is 0 or that there are no extra
/// elements at the end of the storage.
pub fn is_contiguous(&self) -> bool {
self.shape.is_contiguous(&self.stride)
}
/// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
pub fn is_fortran_contiguous(&self) -> bool {
self.shape.is_fortran_contiguous(&self.stride)
}
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
let dims = self.shape().dims();
if dim >= dims.len() {
Err(Error::DimOutOfRange {
shape: self.shape().clone(),
dim: dim as i32,
op: "narrow",
}
.bt())?
}
if start + len > dims[dim] {
Err(Error::NarrowInvalidArgs {
shape: self.shape.clone(),
dim,
start,
len,
msg: "start + len > dim_len",
}
.bt())?
}
let mut dims = dims.to_vec();
dims[dim] = len;
Ok(Self {
shape: Shape::from(dims),
stride: self.stride.clone(),
start_offset: self.start_offset + self.stride[dim] * start,
})
}
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
let rank = self.shape.rank();
if rank <= dim1 || rank <= dim2 {
Err(Error::UnexpectedNumberOfDims {
expected: usize::max(dim1, dim2),
got: rank,
shape: self.shape().clone(),
}
.bt())?
}
let mut stride = self.stride().to_vec();
let mut dims = self.shape().dims().to_vec();
dims.swap(dim1, dim2);
stride.swap(dim1, dim2);
Ok(Self {
shape: Shape::from(dims),
stride,
start_offset: self.start_offset,
})
}
pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
let is_permutation =
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
if !is_permutation {
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
idxs
)
}
let stride = self.stride();
let dims = self.shape().dims();
let mut perm_stride = stride.to_vec();
let mut perm_dims = dims.to_vec();
for (i, &idx) in idxs.iter().enumerate() {
perm_stride[i] = stride[idx];
perm_dims[i] = dims[idx];
}
Ok(Self {
shape: Shape::from(perm_dims),
stride: perm_stride,
start_offset: self.start_offset,
})
}
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let shape = shape.into();
if shape.rank() < self.shape().rank() {
return Err(Error::BroadcastIncompatibleShapes {
src_shape: self.shape().clone(),
dst_shape: shape,
}
.bt());
}
let added_dims = shape.rank() - self.shape().rank();
let mut stride = vec![0; added_dims];
for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
.iter()
.zip(self.dims().iter().zip(self.stride()))
{
let s = if dst_dim == src_dim {
src_stride
} else if src_dim != 1 {
return Err(Error::BroadcastIncompatibleShapes {
src_shape: self.shape().clone(),
dst_shape: shape,
}
.bt());
} else {
0
};
stride.push(s)
}
Ok(Self {
shape,
stride,
start_offset: self.start_offset,
})
}
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::from_layout(self)
}
pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
let mut block_len = 1;
let mut contiguous_dims = 0; // These are counted from the right.
for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
if stride != block_len {
break;
}
block_len *= dim;
contiguous_dims += 1;
}
let index_dims = self.dims().len() - contiguous_dims;
if index_dims == 0 {
crate::StridedBlocks::SingleBlock {
start_offset: self.start_offset,
len: block_len,
}
} else {
let block_start_index = crate::StridedIndex::new(
&self.dims()[..index_dims],
&self.stride[..index_dims],
self.start_offset,
);
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
}
}
}
// Returns the contiguous offsets with broadcast if applicable.
pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
let mut left_broadcast = 1;
let mut right_broadcast = 1;
let strides = self.stride();
let dims = self.dims();
let mut start_cont = 0;
let mut end_cont = dims.len();
for (&s, &d) in strides.iter().zip(dims.iter()) {
if s != 0 {
break;
}
start_cont += 1;
left_broadcast *= d;
}
if start_cont == dims.len() {
return Some(ContiguousOffsetsWithBroadcast {
start: self.start_offset,
len: 1,
left_broadcast,
right_broadcast: 1,
});
}
for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
if s != 0 {
break;
}
end_cont -= 1;
right_broadcast *= d;
}
// Check that the inner dims are contiguous
let strides = &strides[start_cont..end_cont];
let dims = &dims[start_cont..end_cont];
let mut len = 1;
for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
if stride != len {
return None;
}
len *= dim;
}
Some(ContiguousOffsetsWithBroadcast {
start: self.start_offset,
len,
left_broadcast,
right_broadcast,
})
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContiguousOffsetsWithBroadcast {
pub start: usize,
pub len: usize,
pub left_broadcast: usize,
pub right_broadcast: usize,
}
//! ML framework for Rust
//!
//! ```rust
//! use candle_core::{Tensor, DType, Device};
//! # use candle_core::Error;
//! # fn main() -> Result<(), Error>{
//!
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
//!
//! let c = a.matmul(&b)?;
//! # Ok(())}
//! ```
//!
//! ## Features
//!
//! - Simple syntax (looks and feels like PyTorch)
//! - CPU and Cuda backends (and M1 support)
//! - Enable serverless (CPU) small and fast deployments
//! - Model training
//! - Distributed computing (NCCL).
//! - Models out of the box (Llama, Whisper, Falcon, ...)
//!
//! ## FAQ
//!
//! - Why Candle?
//!
//! Candle stems from the need to reduce binary size in order to *enable serverless*
//! possible by making the whole engine smaller than PyTorch very large library volume
//!
//! And simply *removing Python* from production workloads.
//! Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
#[cfg(feature = "accelerate")]
mod accelerate;
pub mod backend;
pub mod backprop;
pub mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
mod custom_op;
mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
mod mkl;
pub mod npy;
pub mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod storage;
mod strided_index;
mod tensor;
mod tensor_cat;
pub mod test_utils;
pub mod utils;
mod variable;
#[cfg(feature = "cudnn")]
pub use cuda_backend::cudnn;
pub use cpu_backend::CpuStorage;
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use variable::Var;
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub trait ToUsize2 {
fn to_usize2(self) -> (usize, usize);
}
impl ToUsize2 for usize {
fn to_usize2(self) -> (usize, usize) {
(self, self)
}
}
impl ToUsize2 for (usize, usize) {
fn to_usize2(self) -> (usize, usize) {
self
}
}
// A simple trait defining a module with forward method using a single argument.
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}
impl<M: Module> Module for Option<&M> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
None => Ok(xs.clone()),
Some(m) => m.forward(xs),
}
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
use super::MetalError;
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
pub(crate) fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)]
pub struct MetalDevice {
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
/// the device itself.
pub(crate) id: DeviceId,
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device,
/// Single command queue for the entire device.
pub(crate) command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
pub(crate) compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
/// (could be linked to FFI communication overhead).
///
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
/// to be reused. However, in order for this to work, we need to guarantee the order of
/// operation, so that this buffer is not being used by another kernel at the same time.
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
pub(crate) buffers: AllocatedBuffers,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.id)
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
pub fn id(&self) -> DeviceId {
self.id
}
pub fn metal_device(&self) -> &metal::Device {
&self.device
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.try_write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
}
*index += 1;
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn device(&self) -> &metal::Device {
&self.device
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer data cannot be read on the CPU directly.
///
/// [`name`] is only used to keep track of the resource origin in case of bugs
pub fn new_buffer(
&self,
element_count: usize,
dtype: DType,
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer can be read on the CPU but will require manual
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
/// Creates a new buffer from data.
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
///
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
/// Create a metal GPU capture trace on [`path`].
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
descriptor.set_output_url(path);
capture
.start_capture(&descriptor)
.map_err(MetalError::from)?;
Ok(())
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {
(size - 1).next_power_of_two() as NSUInteger
}
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
BufferOffset {
buffer,
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
}
}
/// Simple way to catch lock error without
/// depending on T
#[derive(thiserror::Error, Debug)]
pub enum LockError {
#[error("{0}")]
Poisoned(String),
#[error("Would block")]
WouldBlock,
}
impl<T> From<TryLockError<T>> for MetalError {
fn from(value: TryLockError<T>) -> Self {
match value {
TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
}
}
}
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
#[error(transparent)]
KernelError(#[from] candle_metal_kernels::MetalKernelError),
#[error("{0:?}")]
LockError(LockError),
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
#[derive(Debug, Clone)]
pub struct MetalStorage {
/// The actual buffer containing the data.
buffer: Arc<metal::Buffer>,
/// a reference to the device owning this buffer
device: MetalDevice,
/// The count of allocated elements in the buffer
count: usize,
/// The dtype is kept since buffers are untyped.
dtype: DType,
}
impl BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Ok(self.clone())
}
fn dtype(&self) -> DType {
self.dtype
}
fn device(&self) -> &Self::Device {
&self.device
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)),
DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)),
DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)),
DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)),
DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)),
DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)),
DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)),
}
}
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let el = shape.elem_count();
let dtype = self.dtype;
let buffer = device.new_buffer(el, self.dtype, "affine")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
DType::BF16 => "affine_bf16",
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine(
&device.device,
&command_buffer,
&device.kernels,
name,
el,
src,
&buffer,
mul as f32,
add as f32,
)
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "affine_f32_strided",
DType::F16 => "affine_f16_strided",
DType::BF16 => "affine_bf16_strided",
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.dims(),
src,
layout.stride(),
&buffer,
mul as f32,
add as f32,
)
.map_err(MetalError::from)?;
}
Ok(Self::new(buffer, device.clone(), el, dtype))
}
fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let el = shape.elem_count();
let dtype = self.dtype;
let buffer = device.new_buffer(el, self.dtype, "powf")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
DType::BF16 => "powf_bf16",
dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"),
};
candle_metal_kernels::call_powf(
&device.device,
&command_buffer,
&device.kernels,
name,
el,
src,
&buffer,
pow as f32,
)
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "powf_f32_strided",
DType::F16 => "powf_f16_strided",
DType::BF16 => "powf_bf16_strided",
dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"),
};
candle_metal_kernels::call_powf_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.dims(),
src,
layout.stride(),
&buffer,
pow as f32,
)
.map_err(MetalError::from)?;
}
Ok(Self::new(buffer, device.clone(), el, dtype))
}
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let el = shape.elem_count();
let dtype = self.dtype;
let buffer = device.new_buffer(el, self.dtype, "elu")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
DType::BF16 => "elu_bf16",
dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"),
};
candle_metal_kernels::call_elu(
&device.device,
&command_buffer,
&device.kernels,
name,
el,
src,
&buffer,
alpha as f32,
)
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "elu_f32_strided",
DType::F16 => "elu_f16_strided",
DType::BF16 => "elu_bf16_strided",
dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"),
};
candle_metal_kernels::call_elu_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.dims(),
src,
layout.stride(),
&buffer,
alpha as f32,
)
.map_err(MetalError::from)?;
}
Ok(Self::new(buffer, device.clone(), el, dtype))
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
// Source dims and strides with the sum dims at the end.
let mut dims = vec![];
let mut stride = vec![];
let mut dst_el: usize = 1;
for (dim_idx, &d) in src_dims.iter().enumerate() {
if !sum_dims.contains(&dim_idx) {
dst_el *= d;
dims.push(d);
stride.push(src_stride[dim_idx]);
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
&dims,
&stride,
dst_el,
src,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, device, dst_el, dtype))
}
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
let name = match op {
CmpOp::Eq => "eq",
CmpOp::Ne => "ne",
CmpOp::Le => "le",
CmpOp::Ge => "ge",
CmpOp::Lt => "lt",
CmpOp::Gt => "gt",
};
self.binary(name, rhs, lhs_l, rhs_l)
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
let command_buffer = device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::U8, DType::F16) => "cast_u8_f16",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::I64) => "cast_f32_i64",
(DType::F32, DType::U32) => "cast_f32_u32",
(DType::F32, DType::U8) => "cast_f32_u8",
(DType::I64, DType::BF16) => "cast_i64_bf16",
(DType::I64, DType::F16) => "cast_i64_f16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::I64, DType::U32) => "cast_i64_u32",
(DType::I64, DType::U8) => "cast_i64_u8",
(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::F16, DType::I64) => "cast_f16_i64",
(DType::F16, DType::U32) => "cast_f16_u32",
(DType::F16, DType::U8) => "cast_f16_u8",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32",
(DType::BF16, DType::I64) => "cast_bf16_i64",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
}
};
candle_metal_kernels::call_cast_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32_strided",
(DType::U32, DType::U8) => "cast_u32_u8_strided",
(DType::U32, DType::I64) => "cast_u32_i64_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
(DType::U8, DType::F32) => "cast_u8_f32_strided",
(DType::U8, DType::I64) => "cast_u8_i64_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(left, right) => {
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
}
};
candle_metal_kernels::call_cast_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
src,
layout.stride(),
&buffer,
)
.map_err(MetalError::from)?;
}
command_buffer.set_label("to_dtype");
Ok(Self::new(buffer, device.clone(), el_count, dtype))
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label(B::KERNEL);
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("uabs", DType::F16) => contiguous::abs::HALF,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
("uceil", DType::F16) => contiguous::ceil::HALF,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("ucos", DType::F32) => contiguous::cos::FLOAT,
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
("uerf", DType::F16) => contiguous::erf::HALF,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
("uexp", DType::F16) => contiguous::exp::HALF,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
("ufloor", DType::F16) => contiguous::floor::HALF,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
("ulog", DType::F16) => contiguous::log::HALF,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ulog", DType::BF16) => contiguous::log::BFLOAT,
("uneg", DType::F16) => contiguous::neg::HALF,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
("urecip", DType::F16) => contiguous::recip::HALF,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
("urelu", DType::F16) => contiguous::relu::HALF,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
("uround", DType::F16) => contiguous::round::HALF,
("uround", DType::F32) => contiguous::round::FLOAT,
("uround", DType::BF16) => contiguous::round::BFLOAT,
("usilu", DType::F16) => contiguous::silu::HALF,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
("usin", DType::F16) => contiguous::sin::HALF,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usin", DType::BF16) => contiguous::sin::BFLOAT,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
("utanh", DType::F16) => contiguous::tanh::HALF,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
("usign", DType::F16) => contiguous::sign::HALF,
("usign", DType::F32) => contiguous::sign::FLOAT,
("usign", DType::BF16) => contiguous::sign::BFLOAT,
("usign", DType::I64) => contiguous::sign::I64,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
("usin", DType::F32) => strided::sin::FLOAT,
("usqr", DType::F32) => strided::sqr::FLOAT,
("usqrt", DType::F32) => strided::sqrt::FLOAT,
("uneg", DType::F32) => strided::neg::FLOAT,
("uexp", DType::F32) => strided::exp::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("usilu", DType::F32) => strided::silu::FLOAT,
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
("usqrt", DType::F16) => strided::sqrt::HALF,
("uneg", DType::F16) => strided::neg::HALF,
("uexp", DType::F16) => strided::exp::HALF,
("ulog", DType::F16) => strided::log::HALF,
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("usilu", DType::F16) => strided::silu::HALF,
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
};
let dst = BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
src,
layout.stride(),
dst,
)
.map_err(MetalError::from)?;
}
Ok(Self::new(buffer, device.clone(), el_count, dtype))
}
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
self.binary(B::KERNEL, rhs, lhs_l, rhs_l)
}
fn where_cond(
&self,
layout: &Layout,
t: &Self,
t_l: &Layout,
f: &Self,
f_l: &Layout,
) -> Result<Self> {
let device = self.device.clone();
let shape = t_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
let dtype = t.dtype;
let buffer = self.device.new_buffer(el, dtype, "where")?;
let command_buffer = self.device.command_buffer()?;
if t.dtype() != f.dtype() {
crate::bail!(
"Invalid where: different dtypes for values {:?} != {:?}",
t.dtype(),
f.dtype()
);
}
let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32",
(DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64",
(DType::U8, DType::U32) => "where_u8_u32",
(DType::U8, DType::U8) => "where_u8_u8",
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
let t = buffer_o(&t.buffer, t_l, t.dtype);
let f = buffer_o(&f.buffer, f_l, f.dtype);
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
dims,
src,
layout.stride(),
t,
t_l.stride(),
f,
f_l.stride(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, device, el, dtype))
}
fn conv1d(
&self,
layout: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &ParamsConv1D,
) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let strides = layout.stride();
let stride = params.stride;
let dilation = params.dilation;
let padding = params.padding;
let k_size = params.k_size;
let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
let dst_el = dims[0] * l_out * dims[1] * k_size;
let dst = self
.device
.new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col1d_f32",
dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col1d_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
layout.shape().dims(),
strides,
(k_size, stride, padding, dilation),
src,
&dst,
)
.map_err(MetalError::from)?;
let col = Self {
buffer: dst,
device,
count: dst_el,
dtype: self.dtype,
};
let l_out = params.l_out();
let b = params.b_size;
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv_transpose1d(
&self,
layout: &Layout,
k: &Self,
k_layout: &Layout,
params: &ParamsConvTranspose1D,
) -> Result<Self> {
let l_out = params.l_out();
let dst_el = params.c_out * l_out * params.b_size;
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
fn conv2d(
&self,
layout: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &ParamsConv2D,
) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let stride = params.stride;
let dilation = params.dilation;
let padding = params.padding;
let h_k = params.k_h;
let w_k = params.k_w;
let h = dims[2];
let w = dims[3];
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
let dst = self
.device
.new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col_f32",
DType::F16 => "im2col_f16",
DType::BF16 => "im2col_bf16",
DType::U8 => "im2col_u8",
DType::U32 => "im2col_u32",
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
layout.shape().dims(),
layout.stride(),
(h_k, w_k, stride, padding, dilation),
src,
&dst,
)
.map_err(MetalError::from)?;
let col = Self {
buffer: dst,
device,
count: dst_el,
dtype: self.dtype,
};
let h_out = params.out_h();
let w_out = params.out_w();
let b = params.b_size;
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv_transpose2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &ParamsConvTranspose2D,
) -> Result<Self> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let (out_w, out_h) = (params.out_w(), params.out_h());
let dst_el = params.c_out * out_w * out_h * params.b_size;
let dims = l.dims();
if dims.len() != 4 {
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4")
}
let k_dims = kernel_l.dims();
if k_dims.len() != 4 {
crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4")
}
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose2d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose2d_f32",
DType::F16 => "conv_transpose2d_f16",
DType::BF16 => "conv_transpose2d_bf16",
dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
CallConvTranspose2dCfg {
dilation: params.dilation,
stride: params.stride,
padding: params.padding,
output_padding: params.output_padding,
c_out: params.c_out,
out_h,
out_w,
b_size: params.b_size,
input_dims: l.dims(),
input_stride: l.stride(),
kernel_dims: kernel_l.dims(),
kernel_stride: kernel_l.stride(),
input_offset: l.start_offset() * self.dtype.size_in_bytes(),
kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(),
},
&self.buffer,
&kernel.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
fn avg_pool2d(
&self,
inp_l: &Layout,
(w_k, h_k): (usize, usize),
(w_stride, h_stride): (usize, usize),
) -> Result<Self> {
let shape = inp_l.shape();
let (b_size, channels, width, height) = shape.dims4()?;
let strides = inp_l.stride();
let name = match self.dtype {
DType::F32 => "avg_pool2d_f32",
DType::F16 => "avg_pool2d_f16",
DType::BF16 => "avg_pool2d_bf16",
DType::U8 => "avg_pool2d_u8",
DType::U32 => "avg_pool2d_u32",
dtype => crate::bail!("Metal avg_pool2d {dtype:?} not implemented"),
};
let out_w = (width - w_k) / w_stride + 1;
let out_h = (height - h_k) / h_stride + 1;
let dst_el = out_w * out_h * b_size * channels;
let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?;
let command_buffers = self.device.command_buffer()?;
candle_metal_kernels::call_pool2d(
&self.device.device,
&command_buffers,
&self.device.kernels,
name,
inp_l.dims(),
strides,
out_w,
out_h,
w_k,
h_k,
w_stride,
h_stride,
&self.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
fn max_pool2d(
&self,
inp_l: &Layout,
(w_k, h_k): (usize, usize),
(w_stride, h_stride): (usize, usize),
) -> Result<Self> {
let shape = inp_l.shape();
let (b_size, channels, width, height) = shape.dims4()?;
let strides = inp_l.stride();
let name = match self.dtype {
DType::F32 => "max_pool2d_f32",
DType::F16 => "max_pool2d_f16",
DType::BF16 => "max_pool2d_bf16",
DType::U8 => "max_pool2d_u8",
DType::U32 => "max_pool2d_u32",
dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"),
};
let out_w = (width - w_k) / w_stride + 1;
let out_h = (height - h_k) / h_stride + 1;
let dst_el = out_w * out_h * b_size * channels;
let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?;
let command_buffers = self.device.command_buffer()?;
candle_metal_kernels::call_pool2d(
&self.device.device,
&command_buffers,
&self.device.kernels,
name,
inp_l.dims(),
strides,
out_w,
out_h,
w_k,
h_k,
w_stride,
h_stride,
&self.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
crate::bail!("Metal upsample_nearest1d not implemented")
}
fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
// let inp = &inp.slice(inp_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let strides = inp_l.stride();
if dims.len() != 4 {
crate::bail!("unexpected input shape for upsample {dims:?}")
}
let name = match self.dtype {
DType::F32 => "upsample_nearest2d_f32",
DType::F16 => "upsample_nearest2d_f16",
DType::BF16 => "upsample_nearest2d_bf16",
DType::U8 => "upsample_nearest2d_u8",
DType::U32 => "upsample_nearest2d_u32",
dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"),
};
let dst_el = out_w * out_h * dims[0] * dims[1];
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, inp_l, self.dtype);
candle_metal_kernels::call_upsample_nearest_2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
dims,
strides,
out_w,
out_h,
src,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
if !ids_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
};
let ids_el = ids_l.dims()[dim];
let dst_el = ids_l.shape().elem_count();
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_gather(
&device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
ids_el,
dim,
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
}
fn scatter_add(
&self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "sa_u8_f32",
(DType::U8, DType::F16) => "sa_u8_f16",
(DType::U8, DType::BF16) => "sa_u8_bf16",
(DType::U32, DType::F32) => "sa_u32_f32",
(DType::U32, DType::F16) => "sa_u32_f16",
(DType::U32, DType::BF16) => "sa_u32_bf16",
(DType::I64, DType::F32) => "sa_i64_f32",
(DType::I64, DType::F16) => "sa_i64_f16",
(DType::I64, DType::BF16) => "sa_i64_bf16",
_ => Err(MetalError::UnexpectedDType {
msg: "scatter-add ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
l.dims(),
dim,
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
Ok(acc)
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
if !ids_l.is_contiguous() {
crate::bail!("Metal index_select requires contiguous ids")
}
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let ids_el = ids_l.shape().elem_count();
let dst_el = ids_el * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",
(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
ids_el,
dim,
src_l.is_contiguous(),
src_l.dims(),
src_l.stride(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
}
fn index_add(
&self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::I64, DType::BF16) => "ia_i64_bf16",
(DType::I64, DType::F16) => "ia_i64_f16",
(DType::I64, DType::F32) => "ia_i64_f32",
(DType::I64, DType::I64) => "ia_i64_i64",
(DType::I64, DType::U32) => "ia_i64_u32",
(DType::I64, DType::U8) => "ia_i64_u8",
(DType::U32, DType::BF16) => "ia_u32_bf16",
(DType::U32, DType::F16) => "ia_u32_f16",
(DType::U32, DType::F32) => "ia_u32_f32",
(DType::U32, DType::I64) => "ia_u32_i64",
(DType::U32, DType::U32) => "ia_u32_u32",
(DType::U32, DType::U8) => "ia_u32_u8",
(DType::U8, DType::BF16) => "ia_u8_bf16",
(DType::U8, DType::F16) => "ia_u8_f16",
(DType::U8, DType::F32) => "ia_u8_f32",
(DType::U8, DType::I64) => "ia_u8_i64",
(DType::U8, DType::U32) => "ia_u8_u32",
(DType::U8, DType::U8) => "ia_u8_u8",
_ => Err(MetalError::UnexpectedDType {
msg: "index-add ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_add(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
l.dims(),
ids_l.dims(),
dim,
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
Ok(acc)
}
fn matmul(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
dtype => {
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
}
};
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
buffer,
self.device.clone(),
b * m * n,
self.dtype(),
))
}
fn copy2d(
&self,
dst: &mut Self,
d1: usize,
d2: usize,
src_s: usize,
dst_s: usize,
src_o: usize,
dst_o: usize,
) -> Result<()> {
if self.dtype() != dst.dtype() {
crate::bail!(
"copy2d with inconsistent dtypes {:?} {:?}",
self.dtype(),
dst.dtype()
)
}
let command_buffer = self.device.command_buffer()?;
if src_s == d2 && dst_s == d2 {
command_buffer.set_label("copy2d_contiguous");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("copy2d_contiguous");
let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger;
let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
blit.end_encoding();
} else {
let el_count = d1 * d2;
if el_count == 0 {
return Ok(());
}
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::copy2d::FLOAT,
DType::F16 => candle_metal_kernels::copy2d::HALF,
DType::BF16 => candle_metal_kernels::copy2d::BFLOAT,
DType::I64 => candle_metal_kernels::copy2d::I64,
DType::U32 => candle_metal_kernels::copy2d::U32,
DType::U8 => candle_metal_kernels::copy2d::U8,
dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"),
};
candle_metal_kernels::call_copy2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
&self.buffer,
&dst.buffer,
d1,
d2,
src_s,
dst_s,
src_o * self.dtype.size_in_bytes(),
dst_o * self.dtype.size_in_bytes(),
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy2d");
}
Ok(())
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer()?;
if src_l.is_contiguous() && self.dtype == dst.dtype() {
command_buffer.set_label("copy_contiguous");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("copy_contiguous");
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
blit.end_encoding();
} else {
let src_shape = src_l.shape();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
DType::I64 => candle_metal_kernels::unary::strided::copy::I64,
DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, src_l, self.dtype);
let dst = BufferOffset {
buffer: &dst.buffer,
offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
src,
src_l.stride(),
dst,
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
}
Ok(())
}
}
impl MetalStorage {
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, count: usize, dtype: DType) -> Self {
Self {
buffer,
device,
count,
dtype,
}
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
pub fn binary(
&self,
op: &'static str,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
use candle_metal_kernels::binary::contiguous;
let (kernel_name, dtype) = match (op, self.dtype) {
("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype),
("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype),
("div", DType::F32) => (contiguous::div::FLOAT, self.dtype),
("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8),
("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8),
("le", DType::F32) => (contiguous::le::FLOAT, DType::U8),
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
("div", DType::F16) => (contiguous::div::HALF, self.dtype),
("eq", DType::F16) => (contiguous::eq::HALF, DType::U8),
("ne", DType::F16) => (contiguous::ne::HALF, DType::U8),
("le", DType::F16) => (contiguous::le::HALF, DType::U8),
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
("add", DType::I64) => (contiguous::add::I64, self.dtype),
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
("div", DType::I64) => (contiguous::div::I64, self.dtype),
("eq", DType::I64) => (contiguous::eq::I64, DType::U8),
("ne", DType::I64) => (contiguous::ne::I64, DType::U8),
("le", DType::I64) => (contiguous::le::I64, DType::U8),
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
("div", DType::U32) => (contiguous::div::U32, self.dtype),
("eq", DType::U32) => (contiguous::eq::U32, DType::U8),
("ne", DType::U32) => (contiguous::ne::U32, DType::U8),
("le", DType::U32) => (contiguous::le::U32, DType::U8),
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
("div", DType::U8) => (contiguous::div::U8, self.dtype),
("eq", DType::U8) => (contiguous::eq::U8, DType::U8),
("ne", DType::U8) => (contiguous::ne::U8, DType::U8),
("le", DType::U8) => (contiguous::le::U8, DType::U8),
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
(name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
}
};
let buffer = device.new_buffer(el_count, dtype, op)?;
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
lhs,
rhs,
&buffer,
)
.map_err(MetalError::from)?;
(buffer, dtype)
} else {
use candle_metal_kernels::binary::strided;
let (kernel_name, dtype) = match (op, self.dtype) {
("badd", DType::F32) => (strided::add::FLOAT, self.dtype),
("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype),
("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype),
("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype),
("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype),
("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype),
("eq", DType::F32) => (strided::eq::FLOAT, DType::U8),
("ne", DType::F32) => (strided::ne::FLOAT, DType::U8),
("le", DType::F32) => (strided::le::FLOAT, DType::U8),
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
("bdiv", DType::F16) => (strided::div::HALF, self.dtype),
("bminimum", DType::F16) => (strided::min::HALF, self.dtype),
("bmaximum", DType::F16) => (strided::max::HALF, self.dtype),
("eq", DType::F16) => (strided::eq::HALF, DType::U8),
("ne", DType::F16) => (strided::ne::HALF, DType::U8),
("le", DType::F16) => (strided::le::HALF, DType::U8),
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
("badd", DType::I64) => (strided::add::I64, self.dtype),
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
("bdiv", DType::I64) => (strided::div::I64, self.dtype),
("bminimum", DType::I64) => (strided::min::I64, self.dtype),
("bmaximum", DType::I64) => (strided::max::I64, self.dtype),
("eq", DType::I64) => (strided::eq::I64, DType::U8),
("ne", DType::I64) => (strided::ne::I64, DType::U8),
("le", DType::I64) => (strided::le::I64, DType::U8),
("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8),
("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
("bdiv", DType::U32) => (strided::div::U32, self.dtype),
("bminimum", DType::U32) => (strided::min::U32, self.dtype),
("bmaximum", DType::U32) => (strided::max::U32, self.dtype),
("eq", DType::U32) => (strided::eq::U32, DType::U8),
("ne", DType::U32) => (strided::ne::U32, DType::U8),
("le", DType::U32) => (strided::le::U32, DType::U8),
("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8),
("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
("bdiv", DType::U8) => (strided::div::U8, self.dtype),
("bminimum", DType::U8) => (strided::min::U8, self.dtype),
("bmaximum", DType::U8) => (strided::max::U8, self.dtype),
("eq", DType::U8) => (strided::eq::U8, DType::U8),
("ne", DType::U8) => (strided::ne::U8, DType::U8),
("le", DType::U8) => (strided::le::U8, DType::U8),
("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8),
(name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
}
};
let buffer = device.new_buffer(el_count, dtype, op)?;
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
lhs,
lhs_l.stride(),
rhs,
rhs_l.stride(),
&buffer,
)
.map_err(MetalError::from)?;
(buffer, dtype)
};
command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), el_count, dtype))
}
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
let buffer = self.device.new_buffer_managed(size)?;
{
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
blit.end_encoding();
}
self.device.wait_until_completed()?;
Ok(read_to_vec(&buffer, self.count))
}
}
impl BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new());
let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self {
id: DeviceId::new(),
device,
command_queue,
command_buffer,
command_buffer_index,
compute_per_buffer,
buffers,
kernels,
seed,
})
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal {
gpu_id: self.registry_id() as usize,
}
}
fn same_device(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-uninit")?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let size = shape.elem_count() * dtype.size_in_bytes();
let buffer = self.allocate_zeros(size)?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let (count, buffer) = match storage {
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
};
Ok(Self::Storage::new(
buffer?,
self.clone(),
count,
storage.dtype(),
))
}
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<Self::Storage> {
self.storage_from_cpu_storage(&storage)
}
fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
min: f64,
max: f64,
) -> Result<Self::Storage> {
let name = match dtype {
DType::F32 => "rand_uniform_f32",
DType::F16 => "rand_uniform_f16",
DType::BF16 => "rand_uniform_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_uniform(
&self.device,
&command_buffer,
&self.kernels,
name,
min as f32,
max as f32,
shape.elem_count(),
&self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
let name = match dtype {
DType::F32 => "rand_normal_f32",
DType::F16 => "rand_normal_f16",
DType::BF16 => "rand_normal_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_normal(
&self.device,
&command_buffer,
&self.kernels,
name,
mean as f32,
stddev as f32,
shape.elem_count(),
&self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
fn set_seed(&self, seed: u64) -> Result<()> {
let seed: u32 = seed.try_into().map_err(|_| {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
})?;
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
let contents = seed_buffer.contents();
unsafe {
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1);
}
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
Ok(())
}
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec()
}
#![allow(dead_code)]
use libc::{c_char, c_double, c_float, c_int};
mod ffi {
use super::*;
extern "C" {
pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsExp(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdExp(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsLn(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdLn(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsSin(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdSin(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsCos(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdCos(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsSqrt(n: c_int, a: *const c_float, y: *mut c_float);
pub fn vdSqrt(n: c_int, a: *const c_double, y: *mut c_double);
pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsSub(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdSub(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsMul(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn sgemm_(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const c_float,
a: *const c_float,
lda: *const c_int,
b: *const c_float,
ldb: *const c_int,
beta: *const c_float,
c: *mut c_float,
ldc: *const c_int,
);
pub fn dgemm_(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const c_double,
a: *const c_double,
lda: *const c_int,
b: *const c_double,
ldb: *const c_int,
beta: *const c_double,
c: *mut c_double,
ldc: *const c_int,
);
pub fn hgemm_(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const half::f16,
a: *const half::f16,
lda: *const c_int,
b: *const half::f16,
ldb: *const c_int,
beta: *const half::f16,
c: *mut half::f16,
ldc: *const c_int,
);
}
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn sgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: f32,
a: &[f32],
lda: i32,
b: &[f32],
ldb: i32,
beta: f32,
c: &mut [f32],
ldc: i32,
) {
ffi::sgemm_(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn dgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: f64,
a: &[f64],
lda: i32,
b: &[f64],
ldb: i32,
beta: f64,
c: &mut [f64],
ldc: i32,
) {
ffi::dgemm_(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn hgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: half::f16,
a: &[half::f16],
lda: i32,
b: &[half::f16],
ldb: i32,
beta: half::f16,
c: &mut [half::f16],
ldc: i32,
) {
ffi::hgemm_(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[inline]
pub fn vs_exp(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_exp(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_ln(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_sin(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_sin(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_cos(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_cos(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
// The vector functions from mkl can be performed in place by using the same array for input and
// output.
// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-2/vector-mathematical-functions.html
#[inline]
pub fn vs_tanh_inplace(y: &mut [f32]) {
unsafe { ffi::vsTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vs_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
#[inline]
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vd_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
#[inline]
pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
let a_len = a.len();
let b_len = b.len();
let y_len = y.len();
if a_len != y_len || b_len != y_len {
panic!(
"{} a,b,y len mismatch {a_len} {b_len} {y_len}",
stringify!($fn_name)
);
}
unsafe { ffi::$mkl_name(a_len as i32, a.as_ptr(), b.as_ptr(), y.as_mut_ptr()) }
}
};
}
binary_op!(vs_add, f32, vsAdd);
binary_op!(vd_add, f64, vdAdd);
binary_op!(vs_sub, f32, vsSub);
binary_op!(vd_sub, f64, vdSub);
binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);
binary_op!(vs_max, f32, vsFmax);
binary_op!(vd_max, f64, vdFmax);
binary_op!(vs_min, f32, vsFmin);
binary_op!(vd_min, f64, vdFmin);
//! Numpy support for tensors.
//!
//! The spec for the npy format can be found in
//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
//! The functions from this module can be used to read tensors from npy/npz files
//! or write tensors to these files. A npy file contains a single tensor (unnamed)
//! whereas a npz file can contain multiple named tensors. npz files are also compressed.
//!
//! These two formats are easy to use in Python using the numpy library.
//!
//! ```python
//! import numpy as np
//! x = np.arange(10)
//!
//! # Write a npy file.
//! np.save("test.npy", x)
//!
//! # Read a value from the npy file.
//! x = np.load("test.npy")
//!
//! # Write multiple values to a npz file.
//! values = { "x": x, "x_plus_one": x + 1 }
//! np.savez("test.npz", **values)
//!
//! # Load multiple values from a npz file.
//! values = np.loadz("test.npz")
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Write};
use std::path::Path;
const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY";
const NPY_SUFFIX: &str = ".npy";
fn read_header<R: Read>(reader: &mut R) -> Result<String> {
let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
reader.read_exact(&mut magic_string)?;
if magic_string != NPY_MAGIC_STRING {
return Err(Error::Npy("magic string mismatch".to_string()));
}
let mut version = [0u8; 2];
reader.read_exact(&mut version)?;
let header_len_len = match version[0] {
1 => 2,
2 => 4,
otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
};
let mut header_len = vec![0u8; header_len_len];
reader.read_exact(&mut header_len)?;
let header_len = header_len
.iter()
.rev()
.fold(0_usize, |acc, &v| 256 * acc + v as usize);
let mut header = vec![0u8; header_len];
reader.read_exact(&mut header)?;
Ok(String::from_utf8_lossy(&header).to_string())
}
#[derive(Debug, PartialEq)]
struct Header {
descr: DType,
fortran_order: bool,
shape: Vec<usize>,
}
impl Header {
fn shape(&self) -> Shape {
Shape::from(self.shape.as_slice())
}
fn to_string(&self) -> Result<String> {
let fortran_order = if self.fortran_order { "True" } else { "False" };
let mut shape = self
.shape
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(",");
let descr = match self.descr {
DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
DType::I64 => "i8",
DType::U32 => "u4",
DType::U8 => "u1",
};
if !shape.is_empty() {
shape.push(',')
}
Ok(format!(
"{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
))
}
// Hacky parser for the npy header, a typical example would be:
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
fn parse(header: &str) -> Result<Header> {
let header =
header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
let mut parts: Vec<String> = vec![];
let mut start_index = 0usize;
let mut cnt_parenthesis = 0i64;
for (index, c) in header.chars().enumerate() {
match c {
'(' => cnt_parenthesis += 1,
')' => cnt_parenthesis -= 1,
',' => {
if cnt_parenthesis == 0 {
parts.push(header[start_index..index].to_owned());
start_index = index + 1;
}
}
_ => {}
}
}
parts.push(header[start_index..].to_owned());
let mut part_map: HashMap<String, String> = HashMap::new();
for part in parts.iter() {
let part = part.trim();
if !part.is_empty() {
match part.split(':').collect::<Vec<_>>().as_slice() {
[key, value] => {
let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
let _ = part_map.insert(key.to_owned(), value.to_owned());
}
_ => return Err(Error::Npy(format!("unable to parse header {header}"))),
}
}
}
let fortran_order = match part_map.get("fortran_order") {
None => false,
Some(fortran_order) => match fortran_order.as_ref() {
"False" => false,
"True" => true,
_ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))),
},
};
let descr = match part_map.get("descr") {
None => return Err(Error::Npy("no descr in header".to_string())),
Some(descr) => {
if descr.is_empty() {
return Err(Error::Npy("empty descr".to_string()));
}
if descr.starts_with('>') {
return Err(Error::Npy(format!("little-endian descr {descr}")));
}
// the only supported types in tensor are:
// float64, float32, float16,
// complex64, complex128,
// int64, int32, int16, int8,
// uint8, and bool.
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
"e" | "f2" => DType::F16,
"f" | "f4" => DType::F32,
"d" | "f8" => DType::F64,
// "i" | "i4" => DType::S32,
"q" | "i8" => DType::I64,
// "h" | "i2" => DType::S16,
// "b" | "i1" => DType::S8,
"B" | "u1" => DType::U8,
"I" | "u4" => DType::U32,
"?" | "b1" => DType::U8,
// "F" | "F4" => DType::C64,
// "D" | "F8" => DType::C128,
descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
}
}
};
let shape = match part_map.get("shape") {
None => return Err(Error::Npy("no shape in header".to_string())),
Some(shape) => {
let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
if shape.is_empty() {
vec![]
} else {
shape
.split(',')
.map(|v| v.trim().parse::<usize>())
.collect::<std::result::Result<Vec<_>, _>>()?
}
}
};
Ok(Header {
descr,
fortran_order,
shape,
})
}
}
impl Tensor {
// TODO: Add the possibility to read directly to a device?
pub(crate) fn from_reader<R: std::io::Read>(
shape: Shape,
dtype: DType,
reader: &mut R,
) -> Result<Self> {
let elem_count = shape.elem_count();
match dtype {
DType::BF16 => {
let mut data_t = vec![bf16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F16 => {
let mut data_t = vec![f16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F32 => {
let mut data_t = vec![0f32; elem_count];
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F64 => {
let mut data_t = vec![0f64; elem_count];
reader.read_f64_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::U8 => {
let mut data_t = vec![0u8; elem_count];
reader.read_exact(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::U32 => {
let mut data_t = vec![0u32; elem_count];
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::I64 => {
let mut data_t = vec![0i64; elem_count];
reader.read_i64_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
}
}
/// Reads a npy file and return the stored multi-dimensional array as a tensor.
pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
let mut reader = File::open(path.as_ref())?;
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
Self::from_reader(header.shape(), header.descr, &mut reader)
}
/// Reads a npz file and returns the stored multi-dimensional arrays together with their names.
pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>> {
let zip_reader = BufReader::new(File::open(path.as_ref())?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut result = vec![];
for i in 0..zip.len() {
let mut reader = zip.by_index(i)?;
let name = {
let name = reader.name();
name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
};
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
result.push((name, s))
}
Ok(result)
}
/// Reads a npz file and returns the stored multi-dimensional arrays for some specified names.
pub fn read_npz_by_name<T: AsRef<Path>>(path: T, names: &[&str]) -> Result<Vec<Self>> {
let zip_reader = BufReader::new(File::open(path.as_ref())?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut result = vec![];
for name in names.iter() {
let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) {
Ok(reader) => reader,
Err(_) => Err(Error::Npy(format!(
"no array for {name} in {:?}",
path.as_ref()
)))?,
};
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
result.push(s)
}
Ok(result)
}
fn write<T: Write>(&self, f: &mut T) -> Result<()> {
f.write_all(NPY_MAGIC_STRING)?;
f.write_all(&[1u8, 0u8])?;
let header = Header {
descr: self.dtype(),
fortran_order: false,
shape: self.dims().to_vec(),
};
let mut header = header.to_string()?;
let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
for _ in 0..pad % 16 {
header.push(' ')
}
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
self.write_bytes(f)
}
/// Writes a multi-dimensional array in the npy format.
pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()> {
let mut f = File::create(path.as_ref())?;
self.write(&mut f)
}
/// Writes multiple multi-dimensional arrays using the npz format.
pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
ts: &[(S, T)],
path: P,
) -> Result<()> {
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
let options =
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
for (name, tensor) in ts.iter() {
zip.start_file(format!("{}.npy", name.as_ref()), options)?;
tensor.as_ref().write(&mut zip)?
}
Ok(())
}
}
/// Lazy tensor loader.
pub struct NpzTensors {
index_per_name: HashMap<String, usize>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
// re-create a zip reader for each tensor.
}
impl NpzTensors {
pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
let path = path.as_ref().to_owned();
let zip_reader = BufReader::new(File::open(&path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut index_per_name = HashMap::new();
for i in 0..zip.len() {
let file = zip.by_index(i)?;
let name = {
let name = file.name();
name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
};
index_per_name.insert(name, i);
}
Ok(Self {
index_per_name,
path,
})
}
pub fn names(&self) -> Vec<&String> {
self.index_per_name.keys().collect()
}
/// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
/// reading the whole tensor data.
pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
let index = match self.index_per_name.get(name) {
None => crate::bail!("cannot find tensor {name}"),
Some(index) => *index,
};
let zip_reader = BufReader::new(File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_index(index)?;
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
Ok((header.shape(), header.descr))
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
let index = match self.index_per_name.get(name) {
None => return Ok(None),
Some(index) => *index,
};
// We hope that the file has not changed since first reading it.
let zip_reader = BufReader::new(File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_index(index)?;
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
Ok(Some(tensor))
}
}
#[cfg(test)]
mod tests {
use super::Header;
#[test]
fn parse() {
let h = "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }";
assert_eq!(
Header::parse(h).unwrap(),
Header {
descr: crate::DType::F64,
fortran_order: false,
shape: vec![128]
}
);
let h = "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }";
let h = Header::parse(h).unwrap();
assert_eq!(
h,
Header {
descr: crate::DType::F32,
fortran_order: true,
shape: vec![256, 1, 128]
}
);
assert_eq!(
h.to_string().unwrap(),
"{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }"
);
let h = Header {
descr: crate::DType::U32,
fortran_order: false,
shape: vec![],
};
assert_eq!(
h.to_string().unwrap(),
"{'descr': '<u4', 'fortran_order': False, 'shape': (), }"
);
}
}
#![allow(clippy::redundant_closure_call)]
use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
Eq,
Ne,
Le,
Ge,
Lt,
Gt,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
Sum,
Min,
Max,
ArgMin,
ArgMax,
}
impl ReduceOp {
pub(crate) fn name(&self) -> &'static str {
match self {
Self::ArgMax => "argmax",
Self::ArgMin => "argmin",
Self::Min => "min",
Self::Max => "max",
Self::Sum => "sum",
}
}
}
// These ops return the same type as their input type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
Add,
Mul,
Sub,
Div,
Maximum,
Minimum,
}
// Unary ops with no argument
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
Exp,
Log,
Sin,
Cos,
Abs,
Neg,
Recip,
Sqr,
Sqrt,
Gelu,
GeluErf,
Erf,
Relu,
Silu,
Tanh,
Floor,
Ceil,
Round,
Sign,
}
#[derive(Clone)]
pub enum Op {
Binary(Tensor, Tensor, BinaryOp),
Unary(Tensor, UnaryOp),
Cmp(Tensor, CmpOp),
// The third argument is the reduced shape with `keepdim=true`.
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),
#[allow(dead_code)]
Conv1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
Conv2D {
arg: Tensor,
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose2D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
AvgPool2D {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
},
MaxPool2D {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
},
UpsampleNearest1D {
arg: Tensor,
target_size: usize,
},
UpsampleNearest2D {
arg: Tensor,
target_h: usize,
target_w: usize,
},
Cat(Vec<Tensor>, usize),
#[allow(dead_code)] // add is currently unused.
Affine {
arg: Tensor,
mul: f64,
add: f64,
},
ToDType(Tensor),
Copy(Tensor),
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
SliceScatter0(Tensor, Tensor, usize),
Reshape(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Permute(Tensor, Vec<usize>),
Elu(Tensor, f64),
Powf(Tensor, f64),
CustomOp1(
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
),
CustomOp2(
Tensor,
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
),
CustomOp3(
Tensor,
Tensor,
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
),
}
pub trait UnaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
fn bf16(v1: bf16) -> bf16;
fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
fn i64(v1: i64) -> i64;
// There is no very good way to represent optional function in traits so we go for an explicit
// boolean flag to mark the function as existing.
const BF16_VEC: bool = false;
fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
const F16_VEC: bool = false;
fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
const F32_VEC: bool = false;
fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
const F64_VEC: bool = false;
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
pub trait BinaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
fn bf16(v1: bf16, v2: bf16) -> bf16;
fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32;
fn i64(v1: i64, v2: i64) -> i64;
const BF16_VEC: bool = false;
fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
const F16_VEC: bool = false;
fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
const F32_VEC: bool = false;
fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
const F64_VEC: bool = false;
fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
const U8_VEC: bool = false;
fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
const U32_VEC: bool = false;
fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
const I64_VEC: bool = false;
fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
impl BinaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("b", $name);
const V: Self = $op;
#[inline(always)]
fn bf16(v1: bf16, v2: bf16) -> bf16 {
$e(v1, v2)
}
#[inline(always)]
fn f16(v1: f16, v2: f16) -> f16 {
$e(v1, v2)
}
#[inline(always)]
fn f32(v1: f32, v2: f32) -> f32 {
$e(v1, v2)
}
#[inline(always)]
fn f64(v1: f64, v2: f64) -> f64 {
$e(v1, v2)
}
#[inline(always)]
fn u8(v1: u8, v2: u8) -> u8 {
$e(v1, v2)
}
#[inline(always)]
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
#[inline(always)]
fn i64(v1: i64, v2: i64) -> i64 {
$e(v1, v2)
}
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
const F32_VEC: bool = true;
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
const F64_VEC: bool = true;
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
#[inline(always)]
fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
crate::mkl::$f32_vec(xs1, xs2, ys)
}
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
#[inline(always)]
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs1, xs2, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
crate::accelerate::$f32_vec(xs1, xs2, ys)
}
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::accelerate::$f64_vec(xs1, xs2, ys)
}
}
};
}
bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
bin_op!(
Minimum,
"minimum",
|v1, v2| if v1 > v2 { v2 } else { v1 },
vs_min,
vd_min
);
bin_op!(
Maximum,
"maximum",
|v1, v2| if v1 < v2 { v2 } else { v1 },
vs_max,
vd_max
);
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
#[inline(always)]
fn bf16($a: bf16) -> bf16 {
$e
}
#[inline(always)]
fn f16($a: f16) -> f16 {
$e
}
#[inline(always)]
fn f32($a: f32) -> f32 {
$e
}
#[inline(always)]
fn f64($a: f64) -> f64 {
$e
}
#[inline(always)]
fn u8(_: u8) -> u8 {
todo!("no unary function for u8")
}
#[inline(always)]
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[inline(always)]
fn i64(_: i64) -> i64 {
todo!("no unary function for i64")
}
}
};
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
#[inline(always)]
fn bf16($a: bf16) -> bf16 {
$e
}
#[inline(always)]
fn f16($a: f16) -> f16 {
$e
}
#[inline(always)]
fn f32($a: f32) -> f32 {
$e
}
#[inline(always)]
fn f64($a: f64) -> f64 {
$e
}
#[inline(always)]
fn u8(_: u8) -> u8 {
todo!("no unary function for u8")
}
#[inline(always)]
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[inline(always)]
fn i64(_: i64) -> i64 {
todo!("no unary function for i64")
}
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
const F32_VEC: bool = true;
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
const F64_VEC: bool = true;
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::$f32_vec(xs, ys)
}
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::$f32_vec(xs, ys)
}
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::$f64_vec(xs, ys)
}
}
};
}
unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
const V: Self = Gelu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f32_const(0.5)
* v
* (bf16::ONE
+ bf16::tanh(
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f32_const(0.5)
* v
* (f16::ONE
+ f16::tanh(
f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn f64(v: f64) -> f64 {
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "ugelu";
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
const F32_VEC: bool = true;
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_gelu(xs, ys)
}
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
const F64_VEC: bool = true;
#[cfg(any(feature = "mkl", feature = "mkl-dynamic"))]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_gelu(xs, ys)
}
}
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
const V: Self = Erf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
crate::cpu::erf::erf(v)
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
/// Silu operation
impl UnaryOpT for Silu {
const NAME: &'static str = "silu";
const V: Self = Silu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v / (bf16::ONE + (-v).exp())
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v / (f16::ONE + (-v).exp())
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "usilu";
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_silu(xs, ys)
}
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_silu(xs, ys)
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
const V: Self = Ceil;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.ceil()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.ceil()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.ceil()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.ceil()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Floor {
const NAME: &'static str = "floor";
const KERNEL: &'static str = "ufloor";
const V: Self = Floor;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.floor()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.floor()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.floor()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.floor()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Round {
const NAME: &'static str = "round";
const KERNEL: &'static str = "uround";
const V: Self = Round;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.round()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.round()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.round()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.round()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";
const V: Self = GeluErf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for Relu {
const NAME: &'static str = "relu";
const KERNEL: &'static str = "urelu";
const V: Self = Relu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.max(bf16::ZERO)
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.max(f16::ZERO)
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.max(0f32)
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.max(0f64)
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
/// properly checked when creating a new value
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);
impl BackpropOp {
pub(crate) fn none() -> Self {
BackpropOp(None)
}
pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
let op = if arg.track_op() {
Some(f(arg.clone()))
} else {
None
};
Self(op)
}
pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
let op = if arg1.track_op() || arg2.track_op() {
Some(f(arg1.clone(), arg2.clone()))
} else {
None
};
Self(op)
}
pub(crate) fn new3(
arg1: &Tensor,
arg2: &Tensor,
arg3: &Tensor,
f: impl Fn(Tensor, Tensor, Tensor) -> Op,
) -> Self {
let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
} else {
None
};
Self(op)
}
pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
Some(f(args))
} else {
None
};
Self(op)
}
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
}
impl std::ops::Deref for BackpropOp {
type Target = Option<Op>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl UnaryOpT for Sign {
const NAME: &'static str = "sign";
const KERNEL: &'static str = "usign";
const V: Self = Sign;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
}
#[inline(always)]
fn f32(v: f32) -> f32 {
f32::from(v > 0.) - f32::from(v < 0.)
}
#[inline(always)]
fn f64(v: f64) -> f64 {
f64::from(v > 0.) - f64::from(v < 0.)
}
#[inline(always)]
fn u8(v: u8) -> u8 {
u8::min(1, v)
}
#[inline(always)]
fn u32(v: u32) -> u32 {
u32::min(1, v)
}
#[inline(always)]
fn i64(v: i64) -> i64 {
(v > 0) as i64 - (v < 0) as i64
}
}
// Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
const VERBOSE: bool = false;
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
// https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
Proto = 0x80,
Global = b'c',
BinPut = b'q',
LongBinPut = b'r',
EmptyTuple = b')',
Reduce = b'R',
Mark = b'(',
BinUnicode = b'X',
BinInt = b'J',
Tuple = b't',
BinPersId = b'Q',
BinInt1 = b'K',
BinInt2 = b'M',
Tuple1 = 0x85,
Tuple2 = 0x86,
Tuple3 = 0x87,
NewTrue = 0x88,
NewFalse = 0x89,
None = b'N',
BinGet = b'h',
LongBinGet = b'j',
SetItem = b's',
SetItems = b'u',
EmptyDict = b'}',
Dict = b'd',
Build = b'b',
Stop = b'.',
NewObj = 0x81,
EmptyList = b']',
BinFloat = b'G',
Append = b'a',
Appends = b'e',
}
// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
type Error = u8;
fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
match value {
0x80 => Ok(Self::Proto),
b'c' => Ok(Self::Global),
b'q' => Ok(Self::BinPut),
b'r' => Ok(Self::LongBinPut),
b')' => Ok(Self::EmptyTuple),
b'R' => Ok(Self::Reduce),
b'(' => Ok(Self::Mark),
b'X' => Ok(Self::BinUnicode),
b'J' => Ok(Self::BinInt),
b't' => Ok(Self::Tuple),
b'Q' => Ok(Self::BinPersId),
b'K' => Ok(Self::BinInt1),
b'M' => Ok(Self::BinInt2),
b'N' => Ok(Self::None),
0x85 => Ok(Self::Tuple1),
0x86 => Ok(Self::Tuple2),
0x87 => Ok(Self::Tuple3),
0x88 => Ok(Self::NewTrue),
0x89 => Ok(Self::NewFalse),
b'h' => Ok(Self::BinGet),
b'j' => Ok(Self::LongBinGet),
b's' => Ok(Self::SetItem),
b'u' => Ok(Self::SetItems),
b'}' => Ok(Self::EmptyDict),
b'd' => Ok(Self::EmptyDict),
b'b' => Ok(Self::Build),
b'.' => Ok(Self::Stop),
0x81 => Ok(Self::NewObj),
b']' => Ok(Self::EmptyList),
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
value => Err(value),
}
}
}
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
let mut data: Vec<u8> = Vec::with_capacity(32);
r.read_until(b'\n', &mut data)?;
data.pop();
if data.last() == Some(&b'\r') {
data.pop();
}
Ok(data)
}
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
Class {
module_name: String,
class_name: String,
},
Int(i32),
Float(f64),
Unicode(String),
Bool(bool),
None,
Tuple(Vec<Object>),
List(Vec<Object>),
Mark,
Dict(Vec<(Object, Object)>),
Reduce {
callable: Box<Object>,
args: Box<Object>,
},
Build {
callable: Box<Object>,
args: Box<Object>,
},
PersistentLoad(Box<Object>),
}
type OResult<T> = std::result::Result<T, Object>;
impl Object {
pub fn unicode(self) -> OResult<String> {
match self {
Self::Unicode(t) => Ok(t),
_ => Err(self),
}
}
pub fn reduce(self) -> OResult<(Self, Self)> {
match self {
Self::Reduce { callable, args } => Ok((*callable, *args)),
_ => Err(self),
}
}
pub fn none(self) -> OResult<()> {
match self {
Self::None => Ok(()),
_ => Err(self),
}
}
pub fn persistent_load(self) -> OResult<Self> {
match self {
Self::PersistentLoad(t) => Ok(*t),
_ => Err(self),
}
}
pub fn bool(self) -> OResult<bool> {
match self {
Self::Bool(t) => Ok(t),
_ => Err(self),
}
}
pub fn int(self) -> OResult<i32> {
match self {
Self::Int(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> {
match self {
Self::Tuple(t) => Ok(t),
_ => Err(self),
}
}
pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
match self {
Self::Dict(t) => Ok(t),
_ => Err(self),
}
}
pub fn class(self) -> OResult<(String, String)> {
match self {
Self::Class {
module_name,
class_name,
} => Ok((module_name, class_name)),
_ => Err(self),
}
}
pub fn into_tensor_info(
self,
name: Self,
dir_name: &std::path::Path,
) -> Result<Option<TensorInfo>> {
let name = match name.unicode() {
Ok(name) => name,
Err(_) => return Ok(None),
};
let (callable, args) = match self.reduce() {
Ok(callable_args) => callable_args,
_ => return Ok(None),
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
let mut args = args.tuple()?;
args.remove(0).reduce()?
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => return Ok(None),
};
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
Ok(Some(TensorInfo {
name,
dtype,
layout,
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
storage_size,
}))
}
}
impl TryFrom<Object> for String {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Unicode(s) => Ok(s),
other => Err(other),
}
}
}
impl TryFrom<Object> for usize {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Int(s) if s >= 0 => Ok(s as usize),
other => Err(other),
}
}
}
impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Tuple(values) => {
// This does not return the appropriate value in the error case but instead return
// the object related to the first error.
values
.into_iter()
.map(|v| T::try_from(v))
.collect::<std::result::Result<Vec<T>, Self::Error>>()
}
other => Err(other),
}
}
}
#[derive(Debug)]
pub struct Stack {
stack: Vec<Object>,
memo: HashMap<u32, Object>,
}
impl Stack {
pub fn empty() -> Self {
Self {
stack: Vec::with_capacity(512),
memo: HashMap::new(),
}
}
pub fn stack(&self) -> &[Object] {
self.stack.as_slice()
}
pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
loop {
if self.read(r)? {
break;
}
}
Ok(())
}
pub fn finalize(mut self) -> Result<Object> {
self.pop()
}
fn push(&mut self, obj: Object) {
self.stack.push(obj)
}
fn pop(&mut self) -> Result<Object> {
match self.stack.pop() {
None => crate::bail!("unexpected empty stack"),
Some(obj) => Ok(obj),
}
}
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
fn build(&mut self) -> Result<()> {
let args = self.pop()?;
let obj = self.pop()?;
let obj = match (obj, args) {
(Object::Dict(mut obj), Object::Dict(mut args)) => {
obj.append(&mut args);
Object::Dict(obj)
}
(obj, args) => Object::Build {
callable: Box::new(obj),
args: Box::new(args),
},
};
self.push(obj);
Ok(())
}
fn reduce(&mut self) -> Result<()> {
let args = self.pop()?;
let callable = self.pop()?;
#[allow(clippy::single_match)]
let reduced = match &callable {
Object::Class {
module_name,
class_name,
} => {
if module_name == "collections"
&& (class_name == "OrderedDict" || class_name == "defaultdict")
{
// TODO: have a separate ordered dict and a separate default dict.
Some(Object::Dict(vec![]))
} else {
None
}
}
_ => None,
};
let reduced = reduced.unwrap_or_else(|| Object::Reduce {
callable: Box::new(callable),
args: Box::new(args),
});
self.push(reduced);
Ok(())
}
fn last(&mut self) -> Result<&mut Object> {
match self.stack.last_mut() {
None => crate::bail!("unexpected empty stack"),
Some(obj) => Ok(obj),
}
}
fn memo_get(&self, id: u32) -> Result<Object> {
match self.memo.get(&id) {
None => crate::bail!("missing object in memo {id}"),
Some(obj) => {
// Maybe we should use refcounting rather than doing potential large clones here.
Ok(obj.clone())
}
}
}
fn memo_put(&mut self, id: u32) -> Result<()> {
let obj = self.last()?.clone();
self.memo.insert(id, obj);
Ok(())
}
fn persistent_load(&self, id: Object) -> Result<Object> {
Ok(Object::PersistentLoad(Box::new(id)))
}
fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
Ok(Object::Reduce {
callable: Box::new(class),
args: Box::new(args),
})
}
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
let mut mark_idx = None;
for (idx, obj) in self.stack.iter().enumerate().rev() {
if obj == &Object::Mark {
mark_idx = Some(idx);
break;
}
}
match mark_idx {
Some(mark_idx) => {
let objs = self.stack.split_off(mark_idx + 1);
self.stack.pop();
Ok(objs)
}
None => {
crate::bail!("marker object not found")
}
}
}
pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
let op_code = match OpCode::try_from(r.read_u8()?) {
Ok(op_code) => op_code,
Err(op_code) => {
crate::bail!("unknown op-code {op_code}")
}
};
// println!("op: {op_code:?}");
// println!("{:?}", self.stack);
match op_code {
OpCode::Proto => {
let version = r.read_u8()?;
if VERBOSE {
println!("proto {version}");
}
}
OpCode::Global => {
let module_name = read_to_newline(r)?;
let class_name = read_to_newline(r)?;
let module_name = String::from_utf8_lossy(&module_name).to_string();
let class_name = String::from_utf8_lossy(&class_name).to_string();
self.push(Object::Class {
module_name,
class_name,
})
}
OpCode::BinInt1 => {
let arg = r.read_u8()?;
self.push(Object::Int(arg as i32))
}
OpCode::BinInt2 => {
let arg = r.read_u16::<LittleEndian>()?;
self.push(Object::Int(arg as i32))
}
OpCode::BinInt => {
let arg = r.read_i32::<LittleEndian>()?;
self.push(Object::Int(arg))
}
OpCode::BinFloat => {
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
let arg = r.read_f64::<byteorder::BigEndian>()?;
self.push(Object::Float(arg))
}
OpCode::BinUnicode => {
let len = r.read_u32::<LittleEndian>()?;
let mut data = vec![0u8; len as usize];
r.read_exact(&mut data)?;
let data = String::from_utf8(data).map_err(E::wrap)?;
self.push(Object::Unicode(data))
}
OpCode::BinPersId => {
let id = self.pop()?;
let obj = self.persistent_load(id)?;
self.push(obj)
}
OpCode::Tuple => {
let objs = self.pop_to_marker()?;
self.push(Object::Tuple(objs))
}
OpCode::Tuple1 => {
let obj = self.pop()?;
self.push(Object::Tuple(vec![obj]))
}
OpCode::Tuple2 => {
let obj2 = self.pop()?;
let obj1 = self.pop()?;
self.push(Object::Tuple(vec![obj1, obj2]))
}
OpCode::Tuple3 => {
let obj3 = self.pop()?;
let obj2 = self.pop()?;
let obj1 = self.pop()?;
self.push(Object::Tuple(vec![obj1, obj2, obj3]))
}
OpCode::NewTrue => self.push(Object::Bool(true)),
OpCode::NewFalse => self.push(Object::Bool(false)),
OpCode::Append => {
let value = self.pop()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.push(value)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::Appends => {
let objs = self.pop_to_marker()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.extend(objs)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::SetItem => {
let value = self.pop()?;
let key = self.pop()?;
let pydict = self.last()?;
if let Object::Dict(d) = pydict {
d.push((key, value))
} else {
crate::bail!("expected a dict, got {pydict:?}")
}
}
OpCode::SetItems => {
let mut objs = self.pop_to_marker()?;
let pydict = self.last()?;
if let Object::Dict(d) = pydict {
if objs.len() % 2 != 0 {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
d.push((key, value))
}
} else {
crate::bail!("expected a dict, got {pydict:?}")
}
}
OpCode::None => self.push(Object::None),
OpCode::Stop => {
return Ok(true);
}
OpCode::Build => self.build()?,
OpCode::EmptyDict => self.push(Object::Dict(vec![])),
OpCode::Dict => {
let mut objs = self.pop_to_marker()?;
let mut pydict = vec![];
if objs.len() % 2 != 0 {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
}
OpCode::Mark => self.push(Object::Mark),
OpCode::Reduce => self.reduce()?,
OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
OpCode::EmptyList => self.push(Object::List(vec![])),
OpCode::BinGet => {
let arg = r.read_u8()?;
let obj = self.memo_get(arg as u32)?;
self.push(obj)
}
OpCode::LongBinGet => {
let arg = r.read_u32::<LittleEndian>()?;
let obj = self.memo_get(arg)?;
self.push(obj)
}
OpCode::BinPut => {
let arg = r.read_u8()?;
self.memo_put(arg as u32)?
}
OpCode::LongBinPut => {
let arg = r.read_u32::<LittleEndian>()?;
self.memo_put(arg)?
}
OpCode::NewObj => {
let args = self.pop()?;
let class = self.pop()?;
let obj = self.new_obj(class, args)?;
self.push(obj)
}
}
Ok(false)
}
}
impl From<Object> for E {
fn from(value: Object) -> Self {
E::Msg(format!("conversion error on {value:?}"))
}
}
// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
"FloatStorage" => DType::F32,
"DoubleStorage" => DType::F64,
"HalfStorage" => DType::F16,
"BFloat16Storage" => DType::BF16,
"ByteStorage" => DType::U8,
"LongStorage" => DType::I64,
other => {
crate::bail!("unsupported storage type {other}")
}
};
let layout = Layout::new(crate::Shape::from(size), stride, offset);
Ok((layout, dtype, path, storage_size))
}
#[derive(Debug, Clone)]
pub struct TensorInfo {
pub name: String,
pub dtype: DType,
pub layout: Layout,
pub path: String,
pub storage_size: usize,
}
/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P,
verbose: bool,
key: Option<&str>,
) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let zip_file_names = zip
.file_names()
.map(|f| f.to_string())
.collect::<Vec<String>>();
let mut tensor_infos = vec![];
for file_name in zip_file_names.iter() {
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();
stack.read_loop(&mut reader)?;
let obj = stack.finalize()?;
if VERBOSE || verbose {
println!("{obj:#?}");
}
let obj = match obj {
Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable {
Object::Class {
module_name,
class_name,
} if module_name == "__torch__" && class_name == "Module" => *args,
_ => continue,
},
_ => continue,
},
obj => obj,
};
// If key is provided, then we need to extract the state_dict from the object.
let obj = if let Some(key) = key {
if let Object::Dict(key_values) = obj {
key_values
.into_iter()
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
.map(|(_, v)| v)
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
} else {
obj
}
} else {
obj
};
// If the object is a dict, then we can extract the tensor info from it.
// NOTE: We are assuming that the `obj` is state_dict by this stage.
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) {
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
Ok(None) => {}
Err(err) => eprintln!("skipping: {err:?}"),
}
}
}
}
Ok(tensor_infos)
}
/// Lazy tensor loader.
pub struct PthTensors {
tensor_infos: HashMap<String, TensorInfo>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
// re-create a zip reader for each tensor.
}
impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
let tensor_infos = tensor_infos
.into_iter()
.map(|ti| (ti.name.to_string(), ti))
.collect();
let path = path.as_ref().to_owned();
Ok(Self { tensor_infos, path })
}
pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
&self.tensor_infos
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
};
// We hope that the file has not changed since first reading it.
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
let rank = tensor_info.layout.shape().rank();
// Reading the data is a bit tricky as it can be strided, for now only support the basic
// case and when the tensor is fortran contiguous.
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
let start_offset = tensor_info.layout.start_offset();
if start_offset > 0 {
std::io::copy(
&mut reader.by_ref().take(start_offset as u64),
&mut std::io::sink(),
)?;
}
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,
&mut reader,
)?;
if rank > 1 && is_fortran_contiguous {
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
let tensor = tensor.reshape(shape_reversed)?;
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
let tensor = tensor.permute(dim_indeces_reversed)?;
Ok(Some(tensor))
} else {
Ok(Some(tensor))
}
}
}
/// Read all the tensors from a PyTorch pth file with a given key.
///
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path, key)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
if let Some(tensor) = pth.get(name)? {
tensors.push((name.to_string(), tensor))
}
}
Ok(tensors)
}
/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
read_all_with_key(path, None)
}
use super::k_quants::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
#[inline(always)]
pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
let ones = _mm256_set1_epi16(1);
let summed_pairs = _mm256_madd_epi16(ones, x);
_mm256_cvtepi32_ps(summed_pairs)
}
#[inline(always)]
pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
let dot = _mm256_maddubs_epi16(ax, sy);
sum_i16_pairs_float(dot)
}
#[inline(always)]
pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
let res = _mm256_extractf128_ps(x, 1);
let res = _mm_add_ps(res, _mm256_castps256_ps128(x));
let res = _mm_add_ps(res, _mm_movehl_ps(res, res));
let res = _mm_add_ss(res, _mm_movehdup_ps(res));
_mm_cvtss_f32(res)
}
#[inline(always)]
pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
let tmp = _mm_loadu_si128(rsi as *const __m128i);
let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4));
let low_mask = _mm256_set1_epi8(0xF);
_mm256_and_si256(low_mask, bytes)
}
#[inline(always)]
pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
let ax = _mm256_sign_epi8(x, x);
let sy = _mm256_sign_epi8(y, x);
mul_sum_us8_pairs_float(ax, sy)
}
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = bytes_from_nibbles_32(x.qs.as_ptr());
let off = _mm256_set1_epi8(8);
let bx = _mm256_sub_epi8(bx, off);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
const K_SHUFFLE: [u8; 128] = [
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
];
_mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
}
#[inline(always)]
unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
const K_SHUFFLE: [u8; 256] = [
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
];
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}
#[inline(always)]
unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i {
const K_SHUFFLE: [u8; 128] = [
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11,
10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
];
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q6k_8k: {n} is not divisible by {qk}")
}
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let m2 = _mm256_set1_epi8(3);
let m32s = _mm256_set1_epi8(32);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q4 = x.ql.as_ptr();
let mut qh = x.qh.as_ptr();
let mut q8 = y.qs.as_ptr();
let scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 128 {
let is = j * 4;
let scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is));
let scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
let scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
let scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
let q4bits1 = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4bits2 = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4bits_h = _mm256_loadu_si256(qh as *const __m256i);
qh = qh.add(32);
let q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bits_h, m2), 4);
let q4h_1 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 2), m2), 4);
let q4h_2 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 4), m2), 4);
let q4h_3 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 6), m2), 4);
let q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
let q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
let q4_2 =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
let q4_3 =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
let q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
let q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
let q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
let p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
let p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
let p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
let p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
let p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
let p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
let p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
_mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
unsafe {
let m3 = _mm256_set1_epi8(3);
let m4 = _mm_set1_epi8(0xF);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
let mut q2 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
let scales8 = _mm_and_si128(mins_and_scales, m4);
let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
let mins = _mm256_cvtepi8_epi16(mins8);
let prod =
_mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i));
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
let all_scales = _mm256_cvtepi8_epi16(scales8);
let l_scales = _mm256_extracti128_si256(all_scales, 0);
let h_scales = _mm256_extracti128_si256(all_scales, 1);
let scales = [
mm256_set_m128i(l_scales, l_scales),
mm256_set_m128i(h_scales, h_scales),
];
let mut sumi = _mm256_setzero_si256();
for scale in scales {
let q2bits = _mm256_loadu_si256(q2 as *const __m256i);
q2 = q2.add(32);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q2_0 = _mm256_and_si256(q2bits, m3);
let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
let p0 = _mm256_maddubs_epi16(q2_0, q8_0);
let p1 = _mm256_maddubs_epi16(q2_1, q8_1);
let p2 = _mm256_maddubs_epi16(q2_2, q8_2);
let p3 = _mm256_maddubs_epi16(q2_3, q8_3);
let p0 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
let p1 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
let p2 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
let p3 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);
let p0 = _mm256_add_epi32(p0, p1);
let p2 = _mm256_add_epi32(p2, p3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x03030303;
const KMASK2: u32 = 0x0f0f0f0f;
let mut aux = [0u32; 3];
unsafe {
let m3 = _mm256_set1_epi8(3);
let mone = _mm256_set1_epi8(1);
let m32 = _mm_set1_epi8(32);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q3 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
LittleEndian::read_u32_into(&x.scales, &mut aux);
let scales128 = _mm_set_epi32(
(((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
(((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
);
let scales128 = _mm_sub_epi8(scales128, m32);
let all_scales = _mm256_cvtepi8_epi16(scales128);
let l_scales = _mm256_extracti128_si256(all_scales, 0);
let h_scales = _mm256_extracti128_si256(all_scales, 1);
let scales = [
mm256_set_m128i(l_scales, l_scales),
mm256_set_m128i(h_scales, h_scales),
];
// high bit
let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);
let mut sumi = _mm256_setzero_si256();
for (j, scale) in scales.iter().enumerate() {
// load low 2 bits
let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
q3 = q3.add(32);
// Prepare low and high bits
// We hardcode the shifts here to avoid loading them into a separate register
let q3l_0 = _mm256_and_si256(q3bits, m3);
let q3h_0 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4)
};
let q3h_0 = _mm256_slli_epi16(q3h_0, 2);
let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
let q3h_1 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5)
};
let q3h_1 = _mm256_slli_epi16(q3h_1, 2);
let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
let q3h_2 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6)
};
let q3h_2 = _mm256_slli_epi16(q3h_2, 2);
let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
let q3h_3 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7)
};
let q3h_3 = _mm256_slli_epi16(q3h_3, 2);
// load Q8 quants
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we
// can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2
// already subtracted (and so, it is zero if the high bit was not set, and 2 if the
// high bit was set)
let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
// multiply with scales
let p16_0 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
let p16_1 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
let p16_2 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
let p16_3 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);
// accumulate
let p16_0 = _mm256_add_epi32(p16_0, p16_1);
let p16_2 = _mm256_add_epi32(p16_2, p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
}
// multiply with block scale and accumulate
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let mut acc = _mm256_setzero_ps();
let mut acc_m = _mm_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mut q4 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
utmp[3] as i32,
utmp[2] as i32,
utmp[1] as i32,
utmp[0] as i32,
));
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
let q8s = _mm_hadd_epi16(
_mm256_extracti128_si256(q8sums, 0),
_mm256_extracti128_si256(q8sums, 1),
);
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
let scales = mm256_set_m128i(sc128, sc128);
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 64 {
let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4l = _mm256_and_si256(q4bits, m4);
let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
let q8l = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16l = _mm256_maddubs_epi16(q4l, q8l);
let p16l = _mm256_madd_epi16(scale_l, p16l);
sumi = _mm256_add_epi32(sumi, p16l);
let q8h = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16h = _mm256_maddubs_epi16(q4h, q8h);
let p16h = _mm256_madd_epi16(scale_h, p16h);
sumi = _mm256_add_epi32(sumi, p16h);
}
let vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let mzero = _mm_setzero_si128();
let mone = _mm256_set1_epi8(1);
let mut acc = _mm256_setzero_ps();
let mut summs = 0.0;
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mut q5 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
utmp[3] as i32,
utmp[2] as i32,
utmp[1] as i32,
utmp[0] as i32,
));
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
let q8s = _mm_hadd_epi16(
_mm256_extracti128_si256(q8sums, 0),
_mm256_extracti128_si256(q8sums, 1),
);
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
summs += dmin * _mm_extract_epi32(hsum, 0) as f32;
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
let scales = mm256_set_m128i(sc128, sc128);
let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i);
let mut hmask = mone;
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 64 {
let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
q5 = q5.add(32);
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
let q5l_0 = _mm256_and_si256(q5bits, m4);
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_0_right_shift = match j {
0 => _mm256_srli_epi16(q5l_0_shift_input, 0),
1 => _mm256_srli_epi16(q5l_0_shift_input, 2),
2 => _mm256_srli_epi16(q5l_0_shift_input, 4),
3 => _mm256_srli_epi16(q5l_0_shift_input, 6),
_ => unreachable!(),
};
let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4);
let q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
hmask = _mm256_slli_epi16(hmask, 1);
let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
let q5l_1_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_1_right_shift = match j {
0 => _mm256_srli_epi16(q5l_1_shift_input, 1),
1 => _mm256_srli_epi16(q5l_1_shift_input, 3),
2 => _mm256_srli_epi16(q5l_1_shift_input, 5),
3 => _mm256_srli_epi16(q5l_1_shift_input, 7),
_ => unreachable!(),
};
let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4);
let q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
hmask = _mm256_slli_epi16(hmask, 1);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
let p16_0 = _mm256_madd_epi16(scale_0, p16_0);
let p16_1 = _mm256_madd_epi16(scale_1, p16_1);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
}
let vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc) + summs)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (xs, ys) in xs.iter().zip(ys.iter()) {
let mut sumi = _mm256_setzero_si256();
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
for j in (0..QK_K).step_by(32) {
let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
}
let d = _mm256_set1_ps(xs.d * ys.d);
acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
pub fn set_force_dmmv(f: bool) {
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
}
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;
fn ceil_div(p: usize, q: usize) -> usize {
(p + q - 1) / q
}
fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
}
fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
//let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func_bin("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
fn dequantize(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
//let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func_bin(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
let kernel_name = match dtype {
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
//let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func_bin(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
shared_mem_bytes: 0,
};
let params = (data, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
//let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func_bin(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nrows as u32, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let params = (
data,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
let vec = slice.to_vec();
T::to_float(&vec, dst)
}
let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
}
self.device
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.dtoh_sync_copy(data).w()?
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len()
}
pub fn fwd(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
}
}
}
impl QCudaStorage {
fn dequantize_matmul_vec(
&self,
self_shape: &crate::Shape,
rhs: &CudaStorage,
rhs_l: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let (nrows, ncols) = self_shape.dims2()?;
let rhs = rhs.as_cuda_slice::<f32>()?;
let rhs = match rhs_l.contiguous_offsets() {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (with_batch, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k),
[1, k] => (false, k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != *k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
};
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
vec![1, nrows]
};
Ok((out, out_shape.into()))
}
fn dequantize_matmul(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
use crate::backend::BackendStorage;
let (n, k) = self_shape.dims2()?;
let (b, m, k2) = match layout.shape().dims() {
&[b, m, k2] => (b, m, k2),
&[m, k2] => (1, m, k2),
s => crate::bail!("unexpected shape for input {s:?}"),
};
if k2 != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
Ok((out, out_shape.into()))
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &CudaDevice,
data: &[T],
) -> Result<super::QStorage> {
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let data = device.htod_sync_copy(data).w()?;
Ok(QStorage::Cuda(QCudaStorage {
data,
device: device.clone(),
dtype: T::DTYPE,
}))
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn cuda_quantize_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let el = 256;
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
Ok(())
}
#[test]
fn cuda_mmv_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
assert_eq!(vs[0], 5561664.5);
let cuda_storage = dequantize_mul_mat_vec(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
}
}
#![allow(unused)]
use super::GgmlDType;
use crate::{CudaDevice, CudaStorage, Error, Result};
pub struct QCudaStorage {
dtype: GgmlDType,
device: CudaDevice,
}
impl QCudaStorage {
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &CudaStorage,
_layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
Err(Error::NotCompiledWithCudaSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &CudaDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
#![allow(unused)]
use super::GgmlDType;
use crate::{Error, MetalDevice, MetalStorage, Result};
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
}
impl QMetalStorage {
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &MetalStorage,
_layout: &crate::Layout,
) -> Result<(MetalStorage, crate::Shape)> {
Err(Error::NotCompiledWithMetalSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &MetalDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
//! Support for the GGML file format.
use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
Ggjt,
Ggla,
Ggmf,
Ggml,
Ggsn,
}
impl TryFrom<u32> for Magic {
type Error = crate::Error;
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x67676a74 => Self::Ggjt,
0x67676c61 => Self::Ggla,
0x67676d66 => Self::Ggmf,
0x67676d6c => Self::Ggml,
0x6767736e => Self::Ggsn,
_ => crate::bail!("unknown magic {value:08x}"),
};
Ok(magic)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
GgmlUnversioned,
GgmfV1,
GgjtV1,
GgjtV2,
GgjtV3,
}
impl VersionedMagic {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = reader.read_u32::<LittleEndian>()?;
let magic = Magic::try_from(magic)?;
if magic == Magic::Ggml {
return Ok(Self::GgmlUnversioned);
}
let version = reader.read_u32::<LittleEndian>()?;
let versioned_magic = match (magic, version) {
(Magic::Ggmf, 1) => Self::GgmfV1,
(Magic::Ggjt, 1) => Self::GgjtV1,
(Magic::Ggjt, 2) => Self::GgjtV2,
(Magic::Ggjt, 3) => Self::GgjtV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
fn align32(&self) -> bool {
match self {
Self::GgmlUnversioned | Self::GgmfV1 => false,
Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
pub n_vocab: u32,
pub n_embd: u32,
pub n_mult: u32,
pub n_head: u32,
pub n_layer: u32,
pub n_rot: u32,
pub ftype: u32,
}
impl HParams {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let n_vocab = reader.read_u32::<LittleEndian>()?;
let n_embd = reader.read_u32::<LittleEndian>()?;
let n_mult = reader.read_u32::<LittleEndian>()?;
let n_head = reader.read_u32::<LittleEndian>()?;
let n_layer = reader.read_u32::<LittleEndian>()?;
let n_rot = reader.read_u32::<LittleEndian>()?;
let ftype = reader.read_u32::<LittleEndian>()?;
Ok(Self {
n_vocab,
n_embd,
n_mult,
n_head,
n_layer,
n_rot,
ftype,
})
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}
impl Vocab {
fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
let mut token_score_pairs = Vec::with_capacity(n_vocab);
for _index in 0..n_vocab {
let len = reader.read_u32::<LittleEndian>()? as usize;
let mut word = vec![0u8; len];
reader.read_exact(&mut word)?;
let score = reader.read_f32::<LittleEndian>()?;
token_score_pairs.push((word, score))
}
Ok(Self { token_score_pairs })
}
}
fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device {
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
};
super::QTensor::new(data, dims)
}
/// Creates a [Tensor] from a raw GGML tensor.
pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let block_size = ggml_dtype.block_size();
if tensor_elems % block_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
)
}
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
let mut dims = vec![0u32; n_dims as usize];
reader.read_u32_into::<LittleEndian>(&mut dims)?;
// The dimensions are stored in reverse order, see for example:
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
dims.reverse();
let mut name = vec![0u8; name_len as usize];
reader.read_exact(&mut name)?;
let name = String::from_utf8_lossy(&name).into_owned();
if magic.align32() {
let pos = reader.stream_position()?;
reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
}
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
}
pub struct Content {
pub magic: VersionedMagic,
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>,
pub device: Device,
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
let magic = VersionedMagic::read(reader)?;
let hparams = HParams::read(reader)?;
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
let mut tensors = HashMap::new();
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor);
}
let device = device.clone();
Ok(Self {
magic,
hparams,
vocab,
tensors,
device,
})
}
pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
match self.tensors.remove(name) {
None => crate::bail!("cannot find tensor with name '{name}'"),
Some(tensor) => Ok(tensor),
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment