Commit 790211d2 authored by yongshk

Initial commit

.idea
target
Cargo.lock
[package]
name = "candle-layer-norm"
version = "0.0.1"
edition = "2021"
description = "Layer Norm layer for the candle ML framework."
[dependencies]
candle = { version = "*", package = "candle-core", features = ["cuda"] }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
# Candle Cuda Layer Norm
Layer Norm fused operation for the Candle ML framework.
This layer was adapted from https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm.
It implements fused dropout + residual + LayerNorm, building on Apex's FastLayerNorm.
Major changes:
- Add residual.
- Make it work for both pre-norm and post-norm architectures.
- Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
- Implement RMSNorm as an option.
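## Usage

A minimal usage sketch of the crate's `layer_norm` function (defined in `src/lib.rs`), assuming a CUDA device is available and `candle-core` is brought in under the name `candle`, as in this crate's `Cargo.toml`. Device index, shapes and epsilon are arbitrary, illustrative choices:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_layer_norm::layer_norm;

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // The hidden size must be divisible by 8 and at most 8192.
    let x = Tensor::randn(0f32, 1f32, (8, 1024), &device)?.to_dtype(DType::F16)?;
    let gamma = Tensor::ones(1024, DType::F16, &device)?;
    let beta = Tensor::zeros(1024, DType::F16, &device)?;
    let y = layer_norm(&x, &gamma, Some(&beta), 1e-5)?;
    assert_eq!(y.dims(), x.dims());
    Ok(())
}
```

The input must be a rank-2 CUDA tensor; the output has the same shape and dtype as the input.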
// Build script to run nvcc and generate the C glue code for launching the layer-norm kernel.
// The cuda build time is very long so one can set the CANDLE_LAYER_NORM_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
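// For example: CANDLE_LAYER_NORM_BUILD_DIR=$HOME/.cache/candle-layer-norm cargo build --release
// (the directory must already exist, see the canonicalize call below).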
use anyhow::{Context, Result};
use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;
const KERNEL_FILES: [&str; 1] = ["ln_api.cu"];
fn main() -> Result<()> {
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|_| num_cpus::get_physical(),
|s| usize::from_str(&s).unwrap(),
);
rayon::ThreadPoolBuilder::new()
.num_threads(num_cpus)
.build_global()
.unwrap();
println!("cargo:rerun-if-changed=build.rs");
for kernel_file in KERNEL_FILES.iter() {
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
}
println!("cargo:rerun-if-changed=kernels/**.cu");
println!("cargo:rerun-if-changed=kernels/ln_fwd_kernels.cuh");
println!("cargo:rerun-if-changed=kernels/ln_kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/ln_utils.cuh");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_LAYER_NORM_BUILD_DIR") {
Err(_) =>
{
#[allow(clippy::redundant_clone)]
out_dir.clone()
}
Ok(build_dir) => {
let path = PathBuf::from(build_dir);
path.canonicalize().expect(&format!(
"Directory doesn't exists: {} (the current directory is {})",
&path.display(),
std::env::current_dir()?.display()
))
}
};
set_cuda_include_dir()?;
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let compute_cap = compute_cap()?;
let out_file = build_dir.join("liblayernorm.a");
let kernel_dir = PathBuf::from("kernels");
let cu_files: Vec<_> = KERNEL_FILES
.iter()
.map(|f| {
let mut obj_file = out_dir.join(f);
obj_file.set_extension("o");
(kernel_dir.join(f), obj_file)
})
.collect();
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
let should_compile = if out_file.exists() {
kernel_dir
.read_dir()
.expect("kernels folder should exist")
.any(|entry| {
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
let in_modified = entry.metadata().unwrap().modified().unwrap();
in_modified.duration_since(*out_modified).is_ok()
} else {
true
}
})
} else {
true
};
if should_compile {
cu_files
.par_iter()
.map(|(cu_file, obj_file)| {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_BFLOAT16_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("-U__CUDA_NO_BFLOAT162_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT162_CONVERSIONS__")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose")
.arg("-fPIC")
.arg("-Wno-error=return-type");
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
Ok(())
})
.collect::<Result<()>>()?;
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
let mut command = std::process::Command::new("nvcc");
command
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
.args(obj_files);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=layernorm");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
Ok(())
}
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().unwrap();
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute cap
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
#pragma once
#include <unordered_map>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <iostream>
//#ifdef OLD_GENERATOR_PATH
//#include <ATen/CUDAGeneratorImpl.h>
//#else
//#include <ATen/cuda/CUDAGeneratorImpl.h>
//#endif
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams{
size_t elts_per_thread;
size_t workspace_bytes;
size_t barrier_size;
int multi_processor_count;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, gamma1(nullptr)
, rowscale(nullptr)
, colscale(nullptr)
, dropout_keep_p(1.f)
, dropout_scale(1.f)
, is_rms_norm(false)
, workspace(nullptr)
, barrier(nullptr)
{
}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x0;
void *x1;
void *residual;
void *x;
void *dmask;
void *dmask1;
void *mu;
void *rs;
void *gamma;
void *gamma1;
void *rowscale;
void *colscale;
void *x0_subset;
void *z_subset;
float inverse_cols;
float dropout_keep_p;
float dropout_scale;
float rowscale_const;
bool is_rms_norm;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, z1(nullptr)
, beta(nullptr)
, beta1(nullptr)
, epsilon(0.f)
{
}
// Output of LN FWD.
void *z;
void *z1;
void *beta;
void *beta1;
float epsilon;
// Random state.
// at::PhiloxCudaState philox_args;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
extern FwdRegistry FWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeId{};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template<typename T>
struct ResidualType2Key : public Type2Key<T, 4>{};
template<typename T>
struct OutputType2Key : public Type2Key<T, 6>{};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 8>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
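// Illustrative example: for W=fp32, I=fp16, R=fp16, O=fp16, C=fp32 the packed type key is
// (2 << 0) | (0 << 2) | (0 << 4) | (0 << 6) | (2 << 8) = 0x202, so get(1024) returns
// (0x202ull << 32) | 1024.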
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
FwdRegistrar(FwdFunction f){
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
FWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
#include "ln.h"
#include "ln_fwd_kernels.cuh"
#include <iostream>
/*
Ada
Supported Type combinations:
input    residual   compute   weights   output
==============================================
fp32     fp32       fp32      fp32      fp32
fp16     fp32       fp32      fp32      fp16
fp16     fp16       fp32      fp32      fp16
bf16     fp32       fp32      fp32      bf16
bf16     bf16       fp32      fp32      bf16
fp16     fp16       fp32      fp16      fp16
bf16     bf16       fp32      bf16      bf16
Remarks:
Output type = Input type
Compute always in FP32
*/
namespace layer_norm {
FwdRegistry FWD_FUNCS;
uint64_t get_key(uint32_t wtype, uint32_t itype, uint32_t rtype, uint32_t otype, uint32_t ctype, uint64_t hidden_size) {
using namespace layer_norm;
uint64_t type_key = wtype | (itype << 2) | (rtype << 4) | (otype << 6) | (ctype << 8);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
}
layer_norm::FwdFunction & get_fwd_launcher(uint32_t wtype, uint32_t itype, uint32_t rtype, uint32_t otype, uint32_t ctype, uint32_t hidden_size) {
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
return iter->second;
}
REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
extern "C" void run_ln(
void *x,
void *residual,
void *gamma,
void *beta,
void *dst_add,
void *dst,
void *mu,
void *rsigma,
float epsilon,
uint32_t hidden_size_rounded,
uint32_t rows,
uint32_t cols,
int32_t multi_processor_count,
uint32_t wtype,
uint32_t itype,
uint32_t rtype,
uint32_t otype,
uint32_t ctype,
int is_rms_norm
) {
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.multi_processor_count = multi_processor_count;
launch_params.stream = 0;
launch_params.params.dropout_keep_p = 1.f;
launch_params.params.residual = residual;
launch_params.params.rowscale = nullptr;
launch_params.params.colscale = nullptr;
launch_params.params.x0_subset = nullptr;
launch_params.params.z_subset = nullptr;
// Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size_rounded);
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x0 = x;
params.x = dst_add;
params.dmask = nullptr;
params.mu = mu;
params.rs = rsigma;
params.gamma = gamma;
params.beta = beta;
params.z = dst;
params.epsilon = epsilon;
params.dropout_scale = 1.f;
params.inverse_cols = 1.f / float(params.cols);
params.rowscale_const = 1.f;
params.is_rms_norm = is_rms_norm;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Launch the kernel.
launcher(launch_params, false);
}
#pragma once
// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
// #include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
#include <curand_kernel.h>
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"
namespace layer_norm {
template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using input_t = typename Ktraits::input_t;
using residual_t = typename Ktraits::residual_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using mask_t = typename Ktraits::mask_t;
using Ivec = typename Ktraits::Ivec;
using Rvec = typename Ktraits::Rvec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Mvec = typename Ktraits::Mvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
const bool has_residual = params.residual != nullptr;
const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
const input_t *rowscale = static_cast<input_t *>(params.rowscale);
const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
const index_t *z_subset = static_cast<index_t *>(params.z_subset);
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
// curandStatePhilox4_32_10_t state;
// if (Is_dropout) {
// auto seeds = at::cuda::philox::unpack(params.philox_args);
// const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
// curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
//}
const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
Wvec gamma[LDGS];
Wvec beta[LDGS];
Wvec colscale[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
gamma[it].load_from(params.gamma, idx);
if (params.beta != nullptr) {
beta[it].load_from(params.beta, idx);
} else {
beta[it].zero_();
}
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
idx += VEC_COLS_PER_LDG;
}
}
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
const int row_z = !Has_subset ? row + 1 : z_subset[row];
const bool load_x0 = !Has_subset || row_x0 > 0;
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec x0;
Rvec residual;
Rvec x;
Mvec dmask;
if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
if (has_residual) { residual.load_from(params.residual, idx_x); }
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
// the more efficient curand_uniform4.
compute_t x_ij;
if (load_x0) {
//mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
mask_t keep = false;
// if (Is_dropout) { dmask.data.elt[jt] = keep; }
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
// x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
} else {
x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
}
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
}
if (save_x) { x.store_to(params.x, idx_x); }
// if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }
idx_x += VEC_COLS_PER_LDG;
idx_x0 += VEC_COLS_PER_LDG;
}
}
static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
// Need to convert to int, otherwise the subtraction will wrap around.
const index_t valid_partial_vecs_in_warp =
std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
int(THREADS_PER_WARP));
return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
};
stats_t s = stats.template compute<Is_even_cols>(
xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
mu_ptr[row] = mu;
}
compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
}
const bool save_z = !Has_subset || row_z > 0;
if (save_z) {
index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ovec z;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
compute_t g_ij = gamma[it].data.elt[jt];
compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
}
z.store_to(params.z, idx_z);
idx_z += VEC_COLS_PER_LDG;
}
}
}
}
}
} // namespace layer_norm
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename residual_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
bool has_colscale = launch_params.params.colscale != nullptr;
bool has_subset = launch_params.params.x0_subset != nullptr;
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.multi_processor_count * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
});
});
});
});
}
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using residual_t = residual_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
bool Has_colscale,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
weight_t_,
input_t_,
residual_t_,
output_t_,
compute_t_,
index_t_,
THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
static_assert(sizeof(compute_t_) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
// Shared memory size to coalesce the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t HIDDEN_SIZE_,
uint32_t CTAS_PER_ROW_,
uint32_t WARPS_M_,
uint32_t WARPS_N_,
uint32_t BYTES_PER_LDG_ = 16,
typename Base = Kernel_traits_base<
HIDDEN_SIZE_,
weight_t_,
input_t_,
residual_t_,
output_t_,
compute_t_,
index_t_,
WARPS_M_*WARPS_N_*THREADS_PER_WARP
>
>
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using residual_t = typename Base::residual_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
// using mask_t = unsigned char;
using mask_t = bool;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
static_assert(sizeof(input_t) == sizeof(output_t));
static_assert(sizeof(input_t) <= sizeof(residual_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
use core::ffi::{c_int, c_void};
extern "C" {
pub(crate) fn run_ln(
x: *const c_void,
residual: *const c_void,
gamma: *const c_void,
beta: *const c_void,
dst_add: *const c_void,
dst: *const c_void,
mu: *const c_void,
rsigma: *const c_void,
epsilon: f32,
hidden_size_rounded: u32,
rows: u32,
cols: u32,
multi_processor_count: i32,
wtype: u32,
itype: u32,
rtype: u32,
otype: u32,
ctype: u32,
is_rms_norm: c_int,
);
}
mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, DType, Layout, Result, Shape, Storage, Tensor};
use half::{bf16, f16};
use std::ptr;
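// Map a candle dtype to the kernel's internal type id (fp16 = 0, bf16 = 1, fp32 = 2),
// matching layer_norm::TypeId in kernels/ln.h.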
fn layer_norm_internal_type(dtype: DType) -> Result<u32> {
let internal_type = match dtype {
DType::F16 => 0,
DType::BF16 => 1,
DType::F32 => 2,
dtype => candle::bail!("dtype {dtype:?} is not supported"),
};
Ok(internal_type)
}
pub struct LayerNorm {
pub epsilon: f32,
pub is_rms_norm: bool,
pub gamma: Tensor,
pub beta: Option<Tensor>,
}
fn round_multiple(x: usize, m: usize) -> usize {
(x + m - 1) / m * m
}
impl LayerNorm {
fn fwd<
T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
>(
&self,
x: &candle::CudaStorage,
x_l: &Layout,
r: Option<&candle::CudaStorage>,
r_l: Option<&Layout>,
) -> Result<(candle::CudaStorage, Shape)> {
// Assume all tensors are on the same device and take device of x
let dev = x.device();
// Get internal layer norm type id for the given dtype
let layer_norm_type = layer_norm_internal_type(x.dtype())?;
// Make sure that gamma is a CUDA tensor and get the underlying storage
let (g, g_l) = self.gamma.storage_and_layout();
let g = match &*g {
Storage::Cuda(g) => g,
_ => candle::bail!("gamma must be a cuda tensor"),
};
// Get cuda slices for all tensors
let x = x.as_cuda_slice::<T>()?;
let g = g.as_cuda_slice::<T>()?;
// Get cuda views for all tensors
let x = x.slice(x_l.start_offset()..);
let g = g.slice(g_l.start_offset()..);
// Input matrix layout
let rows = x_l.dims()[0];
let cols = x_l.dims()[1];
if !(cols % 8 == 0 && cols <= 8192) {
candle::bail!("hidden size must be % 8 and <= 8192")
}
let x_stride = x_l.stride();
let g_stride = g_l.stride();
let x_rank = x_stride.len();
let g_rank = g_stride.len();
if x_rank != 2 {
candle::bail!("layer-norm expects input tensors of rank 2. Found: {x_rank}")
}
if x_stride[x_rank - 1] != 1 {
candle::bail!("the last dim of x must be contiguous {x_stride:?}")
}
if g_stride[g_rank - 1] != 1 {
candle::bail!("the last dim of g must be contiguous {g_stride:?}")
}
// Round cols to match with the correct kernel
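// e.g. cols = 1000 -> 1024, cols = 2000 -> 2048, cols = 5000 -> 5120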
let cols_rounded = if cols <= 1536 {
round_multiple(cols, 256)
} else if cols <= 3072 {
round_multiple(cols, 512)
} else {
round_multiple(cols, 1024)
};
let is_rms_norm = if self.is_rms_norm { 1 } else { 0 };
// If beta is set, get its device pointer
let b_ptr = if let Some(beta) = &self.beta {
// Make sure that beta is a CUDA tensor and get the underlying storage
let (b, b_l) = beta.storage_and_layout();
let b = match &*b {
Storage::Cuda(b) => b,
_ => candle::bail!("gamma must be a cuda tensor"),
};
let b = b.as_cuda_slice::<T>()?;
let b = b.slice(b_l.start_offset()..);
let b_stride = b_l.stride();
let b_rank = b_stride.len();
if b_stride[b_rank - 1] != 1 {
candle::bail!("the last dim of b must be contiguous {b_stride:?}")
}
*b.device_ptr() as *const core::ffi::c_void
} else {
ptr::null() as *const std::ffi::c_void
};
// If residual is set, get its device pointer
let r_ptr = if let (Some(r), Some(r_l)) = (r, r_l) {
// Check shape
let expected_shape = x_l.shape().dims2()?;
if r_l.shape().dims2()? != expected_shape {
candle::bail!("shape mismatch x {:?} and r {:?}", x_l.shape(), r_l.shape());
}
let r = r.as_cuda_slice::<T>()?;
let r = r.slice(r_l.start_offset()..);
let r_stride = r_l.stride();
let r_rank = r_stride.len();
if r_rank != 2 {
candle::bail!("layer-norm expects input tensors of rank 2. Found: {r_rank}")
}
if r_stride[r_rank - 1] != 1 {
candle::bail!("the last dim of r must be contiguous {r_stride:?}")
}
*r.device_ptr() as *const std::ffi::c_void
} else {
ptr::null() as *const std::ffi::c_void
};
// We store the results of the residual add next to the main results,
// so out has twice as many rows as inp
let out_shape = Shape::from((rows * 2, cols));
let out = unsafe { dev.alloc::<T>(out_shape.elem_count()) }.w()?;
let dst = out.slice(..rows * cols);
let dst_add = out.slice(rows * cols..);
// Alloc internal buffers for the per-row mean (mu) and reciprocal std (rsigma)
let mu = unsafe { dev.alloc::<f32>(rows) }.w()?;
let rsigma = unsafe { dev.alloc::<f32>(rows) }.w()?;
// Get cuda device pointers from cuda slices
let x_ptr = *x.device_ptr() as *const core::ffi::c_void;
let g_ptr = *g.device_ptr() as *const core::ffi::c_void;
let dst_add_ptr = *dst_add.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
let mu_ptr = *mu.device_ptr() as *const core::ffi::c_void;
let rsigma_ptr = *rsigma.device_ptr() as *const core::ffi::c_void;
let multi_processors_count = dev
.attribute(candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)
.unwrap();
unsafe {
// Launch Kernel
ffi::run_ln(
x_ptr,
r_ptr,
g_ptr,
b_ptr,
dst_add_ptr,
dst_ptr,
mu_ptr,
rsigma_ptr,
self.epsilon,
cols_rounded as u32,
rows as u32,
cols as u32,
multi_processors_count,
layer_norm_type,
layer_norm_type,
layer_norm_type,
layer_norm_type,
2,
is_rms_norm,
)
}
let out = candle::CudaStorage::wrap_cuda_slice(out, dev.clone());
Ok((out, out_shape))
}
}
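// CustomOp1 path: normalization without a fused residual add.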
impl candle::CustomOp1 for LayerNorm {
fn name(&self) -> &'static str {
"fused-layer-norm"
}
fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> {
candle::bail!("no cpu support for fused-layer-norm")
}
fn cuda_fwd(
&self,
x: &candle::CudaStorage,
x_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
match x.dtype() {
DType::F16 => self.fwd::<f16>(x, x_l, None, None),
DType::BF16 => self.fwd::<bf16>(x, x_l, None, None),
DType::F32 => self.fwd::<f32>(x, x_l, None, None),
dt => {
candle::bail!("fused-layer-norm is only supported for f32, f16 and bf16 ({dt:?})")
}
}
}
}
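// CustomOp2 path: the second input is a residual that is added to x before normalization.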
impl candle::CustomOp2 for LayerNorm {
fn name(&self) -> &'static str {
"fused-layer-norm"
}
fn cpu_fwd(
&self,
_: &CpuStorage,
_: &Layout,
_: &CpuStorage,
_: &Layout,
) -> Result<(CpuStorage, Shape)> {
candle::bail!("no cpu support for fused-layer-norm")
}
fn cuda_fwd(
&self,
x: &candle::CudaStorage,
x_l: &Layout,
r: &candle::CudaStorage,
r_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
match x.dtype() {
DType::F16 => self.fwd::<f16>(x, x_l, Some(r), Some(r_l)),
DType::BF16 => self.fwd::<bf16>(x, x_l, Some(r), Some(r_l)),
DType::F32 => self.fwd::<f32>(x, x_l, Some(r), Some(r_l)),
dt => {
candle::bail!("fused-layer-norm is only supported for f32, f16 and bf16 ({dt:?})")
}
}
}
}
/// Layer Normalization Layer
///
/// # Arguments
///
/// * `x` - Input tensor of rank 2
/// * `gamma` - Channel scale
/// * `beta` - Channel bias
/// * `epsilon` - A value added to the denominator for numerical stability
///
/// The resulting tensor has the same dimensions as `x`
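///
/// # Example
///
/// A minimal usage sketch mirroring the tests at the bottom of this file; it assumes a
/// CUDA device is available and that the hidden size is a multiple of 8.
///
/// ```no_run
/// use candle::{DType, Device, Result, Tensor};
/// use candle_layer_norm::layer_norm;
///
/// fn main() -> Result<()> {
///     let device = Device::new_cuda(0)?;
///     let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
///     let gamma = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
///     let beta = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
///     let normed = layer_norm(&x, &gamma, Some(&beta), 1e-12)?;
///     assert_eq!(normed.dims(), x.dims());
///     Ok(())
/// }
/// ```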
pub fn layer_norm(
x: &Tensor,
gamma: &Tensor,
beta: Option<&Tensor>,
epsilon: f32,
) -> Result<Tensor> {
let op = LayerNorm {
epsilon,
gamma: gamma.clone(),
beta: beta.cloned(),
is_rms_norm: false,
};
let results = x.apply_op1(op)?;
let rows = x.dims()[0];
results.narrow(0, 0, rows)
}
/// Fused Add Layer Normalization Layer
///
/// # Arguments
///
/// * `x` - Input tensor of rank 2
/// * `res` - Residual tensor of rank 2. Will be added to `x` before normalization. Must have
/// the same shape as `x`.
/// * `gamma` - Channel scale
/// * `beta` - Channel bias
/// * `epsilon` - A value added to the denominator for numerical stability
///
/// The resulting tensors have the same dimensions as `x`.
/// The first tensor is the result of the normalization, the second is the result of the residual add.
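///
/// # Example
///
/// A minimal sketch (assumes a CUDA device); `res` stands in for a residual branch that
/// is added to `x` before the normalization.
///
/// ```no_run
/// use candle::{DType, Device, Result, Tensor};
/// use candle_layer_norm::fused_add_layer_norm;
///
/// fn main() -> Result<()> {
///     let device = Device::new_cuda(0)?;
///     let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
///     let res = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
///     let gamma = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
///     let (normed, added) = fused_add_layer_norm(&x, &res, &gamma, None, 1e-12)?;
///     assert_eq!(normed.dims(), x.dims());
///     assert_eq!(added.dims(), x.dims());
///     Ok(())
/// }
/// ```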
pub fn fused_add_layer_norm(
x: &Tensor,
res: &Tensor,
gamma: &Tensor,
beta: Option<&Tensor>,
epsilon: f32,
) -> Result<(Tensor, Tensor)> {
let op = LayerNorm {
epsilon,
gamma: gamma.clone(),
beta: beta.cloned(),
is_rms_norm: false,
};
let results = x.apply_op2(res, op)?;
let rows = x.dims()[0];
Ok((results.narrow(0, 0, rows)?, results.narrow(0, rows, rows)?))
}
/// RMS Normalization Layer
///
/// # Arguments
///
/// * `x` - Input tensor of rank 2
/// * `gamma` - Channel scale
/// * `beta` - Channel bias
/// * `epsilon` - A value added to the denominator for numerical stability
///
/// The resulting tensor has the same dimensions as `x`
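///
/// # Example
///
/// A minimal sketch (assumes a CUDA device); a gamma of ones leaves the scaling untouched.
///
/// ```no_run
/// use candle::{DType, Device, Result, Tensor};
/// use candle_layer_norm::rms_norm;
///
/// fn main() -> Result<()> {
///     let device = Device::new_cuda(0)?;
///     let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
///     let gamma = Tensor::ones(8, DType::F32, &device)?;
///     let normed = rms_norm(&x, &gamma, None, 1e-6)?;
///     assert_eq!(normed.dims(), x.dims());
///     Ok(())
/// }
/// ```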
pub fn rms_norm(x: &Tensor, gamma: &Tensor, beta: Option<&Tensor>, epsilon: f32) -> Result<Tensor> {
let op = LayerNorm {
epsilon,
gamma: gamma.clone(),
beta: beta.cloned(),
is_rms_norm: true,
};
let results = x.apply_op1(op)?;
let rows = x.dims()[0];
results.narrow(0, 0, rows)
}
/// Fused Add RMS Normalization Layer
///
/// # Arguments
///
/// * `x` - Input tensor of rank 2
/// * `res` - Residual tensor of rank 2. Will be added to `x` before normalization. Must have
/// the same shape as `x`.
/// * `gamma` - Channel scale
/// * `beta` - Channel bias
/// * `epsilon` - A value added to the denominator for numerical stability
///
/// The resulting tensors have the same dimensions as `x`.
/// The first tensor is the result of the normalization, the second is the result of the residual add.
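///
/// # Example
///
/// A minimal sketch (assumes a CUDA device), following the same pattern as
/// [`fused_add_layer_norm`].
///
/// ```no_run
/// use candle::{DType, Device, Result, Tensor};
/// use candle_layer_norm::fused_add_rms_norm;
///
/// fn main() -> Result<()> {
///     let device = Device::new_cuda(0)?;
///     let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
///     let res = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
///     let gamma = Tensor::ones(8, DType::F32, &device)?;
///     let (normed, added) = fused_add_rms_norm(&x, &res, &gamma, None, 1e-6)?;
///     assert_eq!(normed.dims(), x.dims());
///     assert_eq!(added.dims(), x.dims());
///     Ok(())
/// }
/// ```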
pub fn fused_add_rms_norm(
x: &Tensor,
res: &Tensor,
gamma: &Tensor,
beta: Option<&Tensor>,
epsilon: f32,
) -> Result<(Tensor, Tensor)> {
let op = LayerNorm {
epsilon,
gamma: gamma.clone(),
beta: beta.cloned(),
is_rms_norm: true,
};
let results = x.apply_op2(res, op)?;
let rows = x.dims()[0];
Ok((results.narrow(0, 0, rows)?, results.narrow(0, rows, rows)?))
}
#[cfg(test)]
mod tests {
use super::*;
use candle::{DType, Device};
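// Reference implementation built from regular candle ops, used to validate the fused CUDA kernels.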
fn layer_norm_truth(
x: &Tensor,
gamma: &Tensor,
beta: Option<&Tensor>,
epsilon: f64,
rms: bool,
) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_seq_len, hidden_size) = x.shape().dims2()?;
let x = x.to_dtype(internal_dtype)?;
let x = if !rms {
let mean_x = (x.sum_keepdim(1)? / hidden_size as f64)?;
x.broadcast_sub(&mean_x)?
} else {
x
};
let norm_x = (x.sqr()?.sum_keepdim(1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + epsilon)?.sqrt()?)?;
let mut x = x_normed.to_dtype(x_dtype)?.broadcast_mul(gamma)?;
if let Some(beta) = beta {
x = x.broadcast_add(beta)?;
}
Ok(x)
}
fn to_vec2_round(t: Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
let b = 10f32.powi(digits);
let t = t.to_dtype(DType::F32)?.to_vec2::<f32>()?;
let t = t
.iter()
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
.collect();
Ok(t)
}
#[test]
fn test_layer_norm() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let res = layer_norm(&x, &g, Some(&b), 1e-12)?;
let truth = layer_norm_truth(&x, &g, Some(&b), 1e-12, false)?;
assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?);
Ok(())
}
#[test]
fn test_layer_norm_no_bias() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let res = layer_norm(&x, &g, None, 1e-12)?;
let truth = layer_norm_truth(&x, &g, None, 1e-12, false)?;
assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?);
Ok(())
}
#[test]
fn test_rms_norm() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let res = rms_norm(&x, &g, Some(&b), 1e-12)?;
let truth = layer_norm_truth(&x, &g, Some(&b), 1e-12, true)?;
assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?);
Ok(())
}
#[test]
fn test_rms_norm_no_bias() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let res = rms_norm(&x, &g, None, 1e-12)?;
let truth = layer_norm_truth(&x, &g, None, 1e-12, true)?;
assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?);
Ok(())
}
#[test]
fn test_layer_norm_add() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
let r = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let (res, res_add) = fused_add_layer_norm(&x, &r, &g, Some(&b), 1e-12)?;
let truth_add = (x + r)?;
let truth = layer_norm_truth(&truth_add, &g, Some(&b), 1e-12, false)?;
assert_eq!(to_vec2_round(res_add, 3)?, to_vec2_round(truth_add, 3)?);
assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?);
Ok(())
}
#[test]
fn test_rms_norm_add() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
let r = Tensor::randn(0., 1., (4, 8), &device)?.to_dtype(DType::F32)?;
let g = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let b = Tensor::randn(0., 1., 8, &device)?.to_dtype(DType::F32)?;
let (res, res_add) = fused_add_rms_norm(&x, &r, &g, Some(&b), 1e-12)?;
let truth_add = (x + r)?;
let truth = layer_norm_truth(&truth_add, &g, Some(&b), 1e-12, true)?;
assert_eq!(to_vec2_round(res_add, 3)?, to_vec2_round(truth_add, 3)?);
assert_eq!(to_vec2_round(res, 3)?, to_vec2_round(truth, 3)?);
Ok(())
}
}