mod ffi;

use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::{DType, Device, Result, Storage, Tensor};
use half::{bf16, f16};
use std::ffi::{c_int, c_long};

fn apply_rotary_<
    T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
    query: &Tensor,
    key: &Tensor,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_neox: bool,
) -> Result<()> {
    let dtype = query.dtype();
    if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype {
        candle::bail!("apply-rotary expects all tensors to have the same dtype");
    }

    let internal_type = match dtype {
        DType::F16 => 0,
        DType::BF16 => 1,
        DType::F32 => 2,
        dtype => candle::bail!("dtype {dtype:?} is not supported"),
    };

    let (q, q_l) = query.storage_and_layout();
    let q = match &*q {
        Storage::Cuda(q) => q,
        _ => candle::bail!("query must be a cuda tensor"),
    };

    let (k, k_l) = key.storage_and_layout();
    let k = match &*k {
        Storage::Cuda(k) => k,
        _ => candle::bail!("key must be a cuda tensor"),
    };

    let (cc, cc_l) = cos_cache.storage_and_layout();
    let cc = match &*cc {
        Storage::Cuda(cc) => cc,
        _ => candle::bail!("cos_cache must be a cuda tensor"),
    };

    let (sc, sc_l) = sin_cache.storage_and_layout();
    let sc = match &*sc {
        Storage::Cuda(sc) => sc,
        _ => candle::bail!("sin_cache must be a cuda tensor"),
    };

    let q_rank = q_l.stride().len();
    let k_rank = k_l.stride().len();
    let cc_rank = cc_l.stride().len();
    let sc_rank = sc_l.stride().len();

    if q_rank != 3 || k_rank != 3 {
        candle::bail!("apply-rotary expects input tensors of rank 3 (q: {q_l:?}, k: {k_l:?})")
    }

    if cc_rank != 2 || sc_rank != 2 {
        candle::bail!("apply-rotary expects cache tensors of rank 2 (cos: {cc_l:?}, sin: {sc_l:?})")
    }

    // Get cuda slices for all tensors
    let q = q.as_cuda_slice::<T>()?;
    let k = k.as_cuda_slice::<T>()?;
    let cc = cc.as_cuda_slice::<T>()?;
    let sc = sc.as_cuda_slice::<T>()?;

    // Get cuda views for all tensors
    let q = q.slice(q_l.start_offset()..);
    let k = k.slice(k_l.start_offset()..);
    let cc = cc.slice(cc_l.start_offset()..);
    let sc = sc.slice(sc_l.start_offset()..);

    let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?;
    let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?;

    if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) {
        candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
    }

    let rot_dim = cc_l.dims()[1];
    if (num_tokens, rot_dim) != cc_l.shape().dims2()? {
        candle::bail!(
            "shape mismatch cos_cache {:?}, expected {:?}",
            cc_l.shape(),
            (num_tokens, rot_dim)
        )
    }

    if (num_tokens, rot_dim) != sc_l.shape().dims2()? {
        candle::bail!(
            "shape mismatch sin_cache {:?}, expected {:?}",
            sc_l.shape(),
            (num_tokens, rot_dim)
        )
    }

    let query_stride = q_l.stride()[0];
    let key_stride = k_l.stride()[0];

    let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
    let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
    let cc_ptr = *cc.device_ptr() as *const core::ffi::c_void;
    let sc_ptr = *sc.device_ptr() as *const core::ffi::c_void;

    let neox = if is_neox { 1 } else { 0 };

    unsafe {
        ffi::rotary_embedding(
            q_ptr,
            k_ptr,
            cc_ptr,
            sc_ptr,
            neox,
            head_size as c_int,
            num_tokens as c_long,
            rot_dim as c_int,
            num_heads as c_int,
            num_kv_heads as c_int,
            query_stride as c_long,
            key_stride as c_long,
            internal_type,
        )
    }
    Ok(())
}

/// Compute the rotary inverse frequencies `1 / base^(2i / dim)` for `i` in `0..dim / 2`,
/// returned as a `(1, dim / 2)` tensor on `device`.
pub fn inv_freqs(dim: usize, base: f32, device: &Device) -> Result<Tensor> {
    let inv_freq: Vec<_> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / base.powf(i as f32 / dim as f32))
        .collect();
    let inv_freq_len = inv_freq.len();
    Tensor::from_vec(inv_freq, (1, inv_freq_len), device)
}

/// Build cosine and sine caches of shape `(length, dim / 2)` for positions `0..length`
/// from the inverse frequencies, cast to `dtype`.
pub fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
    let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
        .to_dtype(DType::F32)?
        .reshape((length, 1))?;
    let freqs = t.matmul(&inv_freqs)?;
    let cos = freqs.cos()?.to_dtype(dtype)?;
    let sin = freqs.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

/// Apply rotary position encoding in place.
///
/// All four tensors must be CUDA tensors sharing the same dtype (f16, bf16 or f32).
///
/// # Arguments
///
/// * `query` - Query tensor of shape `(num_tokens, num_heads, head_size)`.
/// * `key` - Key tensor of shape `(num_tokens, num_kv_heads, head_size)`.
/// * `cos_cache` - Cosine cache of shape `(num_tokens, rot_dim)`, row `i` aligned with token `i`.
/// * `sin_cache` - Sine cache of shape `(num_tokens, rot_dim)`, row `i` aligned with token `i`.
/// * `is_neox` - Use the NeoX rotary layout instead of the GPT-J (interleaved) style.
pub fn apply_rotary_inplace(
    query: &Tensor,
    key: &Tensor,
    cos_cache: &Tensor,
    sin_cache: &Tensor,
    is_neox: bool,
) -> Result<()> {
    match key.dtype() {
        DType::F16 => apply_rotary_::<f16>(query, key, cos_cache, sin_cache, is_neox),
        DType::BF16 => apply_rotary_::<bf16>(query, key, cos_cache, sin_cache, is_neox),
        DType::F32 => apply_rotary_::<f32>(query, key, cos_cache, sin_cache, is_neox),
        dt => {
            candle::bail!("apply_rotary is only supported for f32, f16 and bf16 ({dt:?})")
        }
    }
}
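
// A minimal usage sketch: build the cos/sin caches with `inv_freqs`/`cos_sin`,
// select the rows matching each token's position so the caches are aligned with
// the inputs, and apply the rotary encoding in place. The head counts, head
// size, sequence length and positions below are illustrative values, and the
// test assumes a CUDA device (ordinal 0) is available.
#[cfg(test)]
mod tests {
    use super::{apply_rotary_inplace, cos_sin, inv_freqs};
    use candle::{DType, Device, Result, Tensor};

    #[test]
    fn rotary_shapes() -> Result<()> {
        let device = Device::new_cuda(0)?;
        let (num_tokens, num_heads, num_kv_heads, head_size) = (4, 8, 2, 64);
        let max_seq_len = 32;

        // Rank-3 inputs, as required by `apply_rotary_inplace`.
        let query = Tensor::randn(0f32, 1f32, (num_tokens, num_heads, head_size), &device)?;
        let key = Tensor::randn(0f32, 1f32, (num_tokens, num_kv_heads, head_size), &device)?;

        // Full-length cos/sin tables, then the rows for the positions of the
        // current tokens so that row i of each cache lines up with token i.
        let inv = inv_freqs(head_size, 10_000f32, &device)?;
        let (cos, sin) = cos_sin(max_seq_len, &inv, DType::F32)?;
        let positions = Tensor::new(&[0u32, 1, 2, 3], &device)?;
        let cos = cos.index_select(&positions, 0)?;
        let sin = sin.index_select(&positions, 0)?;

        // Rotates query and key in place using the NeoX layout.
        apply_rotary_inplace(&query, &key, &cos, &sin, true)?;

        // The kernel mutates the buffers in place, so shapes are unchanged.
        assert_eq!(query.dims(), &[num_tokens, num_heads, head_size]);
        assert_eq!(key.dims(), &[num_tokens, num_kv_heads, head_size]);
        Ok(())
    }
}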