lib.rs 16 KB
Newer Older
yongshk's avatar
yongshk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
mod ffi;

use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, DType, Layout, Result, Shape, Storage, Tensor};
use half::{bf16, f16};
use std::ptr;

/// Translates a candle dtype into the numeric element-type id handed to the CUDA
/// kernels: `0` for f16, `1` for bf16 and `2` for f32. Any other dtype is rejected.
fn layer_norm_internal_type(dtype: DType) -> Result<u32> {
    // NOTE(review): these ids presumably mirror an enum on the native side of `ffi`; confirm.
    match dtype {
        DType::F16 => Ok(0),
        DType::BF16 => Ok(1),
        DType::F32 => Ok(2),
        _ => candle::bail!("dtype {dtype:?} is not supported"),
    }
}

/// Parameters of the fused layer-norm / RMS-norm custom op.
///
/// The same struct backs both the single-input op (`candle::CustomOp1`, plain
/// normalization) and the two-input op (`candle::CustomOp2`, normalization of
/// `x + residual`).
#[derive(Debug, Clone)]
pub struct LayerNorm {
    /// Value added to the denominator for numerical stability.
    pub epsilon: f32,
    /// Selects RMS normalization instead of layer normalization
    /// (forwarded to the kernel as its `is_rms_norm` flag).
    pub is_rms_norm: bool,
    /// Per-channel scale. Must be a CUDA tensor of the same dtype as the input;
    /// expected to hold one value per hidden unit.
    pub gamma: Tensor,
    /// Optional per-channel bias, with the same requirements as `gamma`.
    pub beta: Option<Tensor>,
}

/// Rounds `x` up to the nearest multiple of `m`.
///
/// Unlike the classic `(x + m - 1) / m * m`, this never overflows in an
/// intermediate step. It can only overflow when the rounded result itself does
/// not fit in a `usize`.
///
/// # Panics
///
/// Panics if `m == 0`.
fn round_multiple(x: usize, m: usize) -> usize {
    match x % m {
        0 => x,
        rem => x + (m - rem),
    }
}

impl LayerNorm {
    fn fwd<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        &self,
        x: &candle::CudaStorage,
        x_l: &Layout,
        r: Option<&candle::CudaStorage>,
        r_l: Option<&Layout>,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // Assume all tensors are on the same device and take device of x
        let dev = x.device();

        // Get internal layer norm type id for the given dtype
        let layer_norm_type = layer_norm_internal_type(x.dtype())?;

        // Make sure that gamma is a CUDA tensor and get the underlying storage
        let (g, g_l) = self.gamma.storage_and_layout();
        let g = match &*g {
            Storage::Cuda(g) => g,
            _ => candle::bail!("gamma must be a cuda tensor"),
        };

        // Get cuda slices for all tensors
        let x = x.as_cuda_slice::<T>()?;
        let g = g.as_cuda_slice::<T>()?;

        // Get cuda views for all tensors
        let x = x.slice(x_l.start_offset()..);
        let g = g.slice(g_l.start_offset()..);

        // Input matrix layout
        let rows = x_l.dims()[0];
        let cols = x_l.dims()[1];

        if !(cols % 8 == 0 && cols <= 8192) {
            candle::bail!("hidden size must be % 8 and <= 8192")
        }

        let x_stride = x_l.stride();
        let g_stride = g_l.stride();

        let x_rank = x_stride.len();
        let g_rank = g_stride.len();

        if x_rank != 2 {
            candle::bail!("layer-norm expects input tensors of rank 2. Found: {x_rank}")
        }
        if x_stride[x_rank - 1] != 1 {
            candle::bail!("the last dim of x must be contiguous {x_stride:?}")
        }
        if g_stride[g_rank - 1] != 1 {
            candle::bail!("the last dim of g must be contiguous {g_stride:?}")
        }

        // Round cols to match with the correct kernel
        let cols_rounded = if cols <= 1536 {
            round_multiple(cols, 256)
        } else if cols <= 3072 {
            round_multiple(cols, 512)
        } else {
            round_multiple(cols, 1024)
        };

        let is_rms_norm = if self.is_rms_norm { 1 } else { 0 };

        // If beta is et, get ids device pointer
        let b_ptr = if let Some(beta) = &self.beta {
            // Make sure that beta is a CUDA tensor and get the underlying storage
            let (b, b_l) = beta.storage_and_layout();
            let b = match &*b {
                Storage::Cuda(b) => b,
                _ => candle::bail!("gamma must be a cuda tensor"),
            };

            let b = b.as_cuda_slice::<T>()?;
            let b = b.slice(b_l.start_offset()..);

            let b_stride = b_l.stride();
            let b_rank = b_stride.len();

            if b_stride[b_rank - 1] != 1 {
                candle::bail!("the last dim of b must be contiguous {b_stride:?}")
            }
            *b.device_ptr() as *const core::ffi::c_void
        } else {
            ptr::null() as *const std::ffi::c_void
        };

        // If residual is set, get its device pointer
        let r_ptr = if let (Some(r), Some(r_l)) = (r, r_l) {
            // Check shape
            let expected_shape = x_l.shape().dims2()?;
            if r_l.shape().dims2()? != expected_shape {
                candle::bail!("shape mismatch x {:?} and r {:?}", x_l.shape(), r_l.shape());
            }

            let r = r.as_cuda_slice::<T>()?;
            let r = r.slice(r_l.start_offset()..);

            let r_stride = r_l.stride();
            let r_rank = r_stride.len();

            if r_rank != 2 {
                candle::bail!("layer-norm expects input tensors of rank 2. Found: {r_rank}")
            }

            if r_stride[r_rank - 1] != 1 {
                candle::bail!("the last dim of r must be contiguous {r_stride:?}")
            }
            *r.device_ptr() as *const std::ffi::c_void
        } else {
            ptr::null() as *const std::ffi::c_void
        };

        // We will store the results of the residual add next to the main results
        // so out has the same shape as inp * 2
        let out_shape = Shape::from((rows * 2, cols));

        let out = unsafe { dev.alloc::<T>(out_shape.elem_count()) }.w()?;
        let dst = out.slice(..rows * cols);
        let dst_add = out.slice(rows * cols..);

        // Alloc internal buffers
        let mu = unsafe { dev.alloc::<f32>(rows) }.w()?;
        let rsigma = unsafe { dev.alloc::<f32>(rows) }.w()?;

        // Get cuda device pointers from cuda slices
        let x_ptr = *x.device_ptr() as *const core::ffi::c_void;
        let g_ptr = *g.device_ptr() as *const core::ffi::c_void;
        let dst_add_ptr = *dst_add.device_ptr() as *const core::ffi::c_void;
        let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
        let mu_ptr = *mu.device_ptr() as *const core::ffi::c_void;
        let rsigma_ptr = *rsigma.device_ptr() as *const core::ffi::c_void;

        let multi_processors_count = dev
            .attribute(CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
            .unwrap();

        unsafe {
            // Launch Kernel
            ffi::run_ln(
                x_ptr,
                r_ptr,
                g_ptr,
                b_ptr,
                dst_add_ptr,
                dst_ptr,
                mu_ptr,
                rsigma_ptr,
                self.epsilon,
                cols_rounded as u32,
                rows as u32,
                cols as u32,
                multi_processors_count,
                layer_norm_type,
                layer_norm_type,
                layer_norm_type,
                layer_norm_type,
                2,
                is_rms_norm,
            )
        }

        let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone());

        Ok((out, out_shape))
    }
}

impl candle::CustomOp1 for LayerNorm {
    /// Name under which candle reports this op.
    fn name(&self) -> &'static str {
        "fused-layer-norm"
    }

    /// There is no CPU kernel; this always returns an error.
    fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> {
        candle::bail!("no cpu support for fused-layer-norm")
    }

    /// Normalizes `x` on the GPU without a residual input, dispatching on its dtype.
    fn cuda_fwd(
        &self,
        x: &candle::CudaStorage,
        x_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // Single-input form: no residual storage or layout.
        let (residual, residual_layout) = (None, None);
        match x.dtype() {
            DType::F32 => self.fwd::<f32>(x, x_l, residual, residual_layout),
            DType::F16 => self.fwd::<f16>(x, x_l, residual, residual_layout),
            DType::BF16 => self.fwd::<bf16>(x, x_l, residual, residual_layout),
            dt => {
                candle::bail!("fused-layer-norm is only supported for f32, f16 and bf16 ({dt:?})")
            }
        }
    }
}

impl candle::CustomOp2 for LayerNorm {
    /// Name under which candle reports this op.
    fn name(&self) -> &'static str {
        "fused-layer-norm"
    }

    /// There is no CPU kernel; this always returns an error.
    fn cpu_fwd(
        &self,
        _: &CpuStorage,
        _: &Layout,
        _: &CpuStorage,
        _: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("no cpu support for fused-layer-norm")
    }

    /// Adds the residual `r` to `x` and normalizes the sum on the GPU,
    /// dispatching on the dtype of `x`.
    fn cuda_fwd(
        &self,
        x: &candle::CudaStorage,
        x_l: &Layout,
        r: &candle::CudaStorage,
        r_l: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        // Two-input form: the residual is always present.
        let (r, r_l) = (Some(r), Some(r_l));
        match x.dtype() {
            DType::F32 => self.fwd::<f32>(x, x_l, r, r_l),
            DType::F16 => self.fwd::<f16>(x, x_l, r, r_l),
            DType::BF16 => self.fwd::<bf16>(x, x_l, r, r_l),
            dt => {
                candle::bail!("fused-layer-norm is only supported for f32, f16 and bf16 ({dt:?})")
            }
        }
    }
}

/// Fused layer normalization.
///
/// # Arguments
///
/// * `x` - Input tensor of rank 2, shape `(rows, hidden)`
/// * `gamma` - Per-channel scale
/// * `beta` - Optional per-channel bias
/// * `epsilon` - A value added to the denominator for numerical stability
///
/// Returns a tensor with the same dimensions as `x`.
///
/// # Errors
///
/// Fails when the inputs are not CUDA tensors, when `x` is not rank 2, or when the
/// hidden size is not a multiple of 8 or exceeds 8192.
pub fn layer_norm(
    x: &Tensor,
    gamma: &Tensor,
    beta: Option<&Tensor>,
    epsilon: f32,
) -> Result<Tensor> {
    let fused = x.apply_op1(LayerNorm {
        gamma: gamma.clone(),
        beta: beta.cloned(),
        epsilon,
        is_rms_norm: false,
    })?;
    // The op yields a (2 * rows, hidden) buffer; the normalized rows come first.
    fused.narrow(0, 0, x.dims()[0])
}

/// Fused residual-add + layer normalization.
///
/// # Arguments
///
/// * `x` - Input tensor of rank 2
/// * `res` - Residual tensor of rank 2. Will be added to `x` before normalization. Must have
///   the same shape as `x`.
/// * `gamma` - Per-channel scale
/// * `beta` - Optional per-channel bias
/// * `epsilon` - A value added to the denominator for numerical stability
///
/// Returns `(normalized, sum)`, both shaped like `x`: the first is the normalized
/// result, the second is the result of the residual add.
pub fn fused_add_layer_norm(
    x: &Tensor,
    res: &Tensor,
    gamma: &Tensor,
    beta: Option<&Tensor>,
    epsilon: f32,
) -> Result<(Tensor, Tensor)> {
    let op = LayerNorm {
        gamma: gamma.clone(),
        beta: beta.cloned(),
        epsilon,
        is_rms_norm: false,
    };
    let stacked = x.apply_op2(res, op)?;
    let n = x.dims()[0];
    // Top half: normalized output; bottom half: x + res.
    let normed = stacked.narrow(0, 0, n)?;
    let summed = stacked.narrow(0, n, n)?;
    Ok((normed, summed))
}

/// Fused RMS normalization.
///
/// # Arguments
///
/// * `x` - Input tensor of rank 2, shape `(rows, hidden)`
/// * `gamma` - Per-channel scale
/// * `beta` - Optional per-channel bias
/// * `epsilon` - A value added to the denominator for numerical stability
///
/// Returns a tensor with the same dimensions as `x`.
///
/// # Errors
///
/// Fails when the inputs are not CUDA tensors, when `x` is not rank 2, or when the
/// hidden size is not a multiple of 8 or exceeds 8192.
pub fn rms_norm(x: &Tensor, gamma: &Tensor, beta: Option<&Tensor>, epsilon: f32) -> Result<Tensor> {
    let fused = x.apply_op1(LayerNorm {
        gamma: gamma.clone(),
        beta: beta.cloned(),
        epsilon,
        is_rms_norm: true,
    })?;
    // The op yields a (2 * rows, hidden) buffer; the normalized rows come first.
    fused.narrow(0, 0, x.dims()[0])
}

/// Fused residual-add + RMS normalization.
///
/// # Arguments
///
/// * `x` - Input tensor of rank 2
/// * `res` - Residual tensor of rank 2. Will be added to `x` before normalization. Must have
///   the same shape as `x`.
/// * `gamma` - Per-channel scale
/// * `beta` - Optional per-channel bias
/// * `epsilon` - A value added to the denominator for numerical stability
///
/// Returns `(normalized, sum)`, both shaped like `x`: the first is the normalized
/// result, the second is the result of the residual add.
pub fn fused_add_rms_norm(
    x: &Tensor,
    res: &Tensor,
    gamma: &Tensor,
    beta: Option<&Tensor>,
    epsilon: f32,
) -> Result<(Tensor, Tensor)> {
    let op = LayerNorm {
        gamma: gamma.clone(),
        beta: beta.cloned(),
        epsilon,
        is_rms_norm: true,
    };
    let stacked = x.apply_op2(res, op)?;
    let n = x.dims()[0];
    // Top half: normalized output; bottom half: x + res.
    let normed = stacked.narrow(0, 0, n)?;
    let summed = stacked.narrow(0, n, n)?;
    Ok((normed, summed))
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle::{DType, Device};

    /// Maximum absolute difference tolerated between the fused kernel and the reference.
    const TOL: f32 = 1e-3;

    /// Unfused reference implementation of layer norm (or RMS norm when `rms`),
    /// computed in f32 for half-precision inputs and cast back before scaling.
    fn layer_norm_truth(
        x: &Tensor,
        gamma: &Tensor,
        beta: Option<&Tensor>,
        epsilon: f64,
        rms: bool,
    ) -> Result<Tensor> {
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };

        let (_seq_len, hidden_size) = x.shape().dims2()?;
        let x = x.to_dtype(internal_dtype)?;

        let x = if !rms {
            let mean_x = (x.sum_keepdim(1)? / hidden_size as f64)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };

        let norm_x = (x.sqr()?.sum_keepdim(1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + epsilon)?.sqrt()?)?;

        let mut x = x_normed.to_dtype(x_dtype)?.broadcast_mul(gamma)?;
        if let Some(beta) = beta {
            x = x.broadcast_add(beta)?;
        }
        Ok(x)
    }

    /// Asserts that two rank-2 tensors have the same shape and agree element-wise
    /// within the absolute tolerance `tol`.
    ///
    /// A tolerance is used instead of rounding both sides and comparing for equality.
    /// Rounding makes nearly-equal values that straddle a rounding boundary
    /// (e.g. 0.12349 vs 0.12351 at 3 digits) fail randomly with random inputs.
    fn assert_close(actual: &Tensor, expected: &Tensor, tol: f32) -> Result<()> {
        let actual = actual.to_dtype(DType::F32)?.to_vec2::<f32>()?;
        let expected = expected.to_dtype(DType::F32)?.to_vec2::<f32>()?;
        assert_eq!(actual.len(), expected.len(), "row count mismatch");
        for (i, (a_row, e_row)) in actual.iter().zip(&expected).enumerate() {
            assert_eq!(a_row.len(), e_row.len(), "column count mismatch in row {i}");
            for (j, (a, e)) in a_row.iter().zip(e_row).enumerate() {
                // NaN fails this comparison, which is what we want.
                assert!(
                    (a - e).abs() <= tol,
                    "mismatch at ({i}, {j}): got {a}, expected {e} (tol {tol})"
                );
            }
        }
        Ok(())
    }

    #[test]
    fn test_layer_norm() -> Result<()> {
        let device = Device::new_cuda(0)?;

        let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
        let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
        let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;

        let res = layer_norm(&x, &g, Some(&b), 1e-12)?;
        let truth = layer_norm_truth(&x, &g, Some(&b), 1e-12, false)?;

        assert_close(&res, &truth, TOL)
    }

    #[test]
    fn test_layer_norm_no_bias() -> Result<()> {
        let device = Device::new_cuda(0)?;

        let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
        let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;

        let res = layer_norm(&x, &g, None, 1e-12)?;
        let truth = layer_norm_truth(&x, &g, None, 1e-12, false)?;

        assert_close(&res, &truth, TOL)
    }

    #[test]
    fn test_rms_norm() -> Result<()> {
        let device = Device::new_cuda(0)?;

        let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
        let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
        let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;

        let res = rms_norm(&x, &g, Some(&b), 1e-12)?;
        let truth = layer_norm_truth(&x, &g, Some(&b), 1e-12, true)?;

        assert_close(&res, &truth, TOL)
    }

    #[test]
    fn test_rms_norm_no_bias() -> Result<()> {
        let device = Device::new_cuda(0)?;

        let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
        let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;

        let res = rms_norm(&x, &g, None, 1e-12)?;
        let truth = layer_norm_truth(&x, &g, None, 1e-12, true)?;

        assert_close(&res, &truth, TOL)
    }

    #[test]
    fn test_layer_norm_add() -> Result<()> {
        let device = Device::new_cuda(0)?;

        let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
        let r = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
        let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
        let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;

        let (res, res_add) = fused_add_layer_norm(&x, &r, &g, Some(&b), 1e-12)?;
        let truth_add = (x + r)?;
        let truth = layer_norm_truth(&truth_add, &g, Some(&b), 1e-12, false)?;

        assert_close(&res_add, &truth_add, TOL)?;
        assert_close(&res, &truth, TOL)
    }

    #[test]
    fn test_rms_norm_add() -> Result<()> {
        let device = Device::new_cuda(0)?;

        let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
        let r = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
        let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
        let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;

        let (res, res_add) = fused_add_rms_norm(&x, &r, &g, Some(&b), 1e-12)?;
        let truth_add = (x + r)?;
        let truth = layer_norm_truth(&truth_add, &g, Some(&b), 1e-12, true)?;

        assert_close(&res_add, &truth_add, TOL)?;
        assert_close(&res, &truth, TOL)
    }
}