Commit 3141f9f9 authored by carlushuang's avatar carlushuang
Browse files

update some description and fix format

parent 6451d7fa
...@@ -14,7 +14,7 @@ In training case the mean/variance need to store out (TBD, not supported yet) ...@@ -14,7 +14,7 @@ In training case the mean/variance need to store out (TBD, not supported yet)
![](misc/pnorm.png) ![](misc/pnorm.png)
since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example also support this feature. Note that `prenorm`/`postnorm` always need to fuse a `shortcut` before the actual layernorm computation, the only difference is whether to store the added element to global, `prenorm` need this. You can use `-fadd=1` to test `prenorm`(pre-add+store), or `-fadd=2` to test `postnorm`(pre-add) since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example boosts this feature by kernel fusion. Note that `prenorm`/`postnorm` always need to do elementwise-add a `shortcut` before the actual layernorm computation, and optionally store out the result to global. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without store out.
## build ## build
``` ```
......
...@@ -32,10 +32,7 @@ auto create_args(int argc, char* argv[]) ...@@ -32,10 +32,7 @@ auto create_args(int argc, char* argv[])
.insert("kname", "1", "print kernel name or not") .insert("kname", "1", "print kernel name or not")
.insert("prec_i", "fp16", "input precision") .insert("prec_i", "fp16", "input precision")
.insert("prec_o", "auto", "output precision, set auto will be the same as input") .insert("prec_o", "auto", "output precision, set auto will be the same as input")
.insert( .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
"fadd",
"0",
"fused-add, 0:no fused add, 1:fused-prenorm(preadd+store), 2:fused-postnorm(preadd)")
.insert("fsweep", "0", "fused-sweep") .insert("fsweep", "0", "fused-sweep")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter"); .insert("repeat", "20", "hot iter");
......
...@@ -10,12 +10,10 @@ namespace ck_tile { ...@@ -10,12 +10,10 @@ namespace ck_tile {
enum class Layernorm2dFusedAddEnum enum class Layernorm2dFusedAddEnum
{ {
NO_ADD = 0, NO_ADD = 0,
// fused add before layernorm (prenorm), and store result to global // fused add before layernorm and store result to global
PRE_ADD_STORE = 1, PRE_ADD_STORE = 1,
PRE_NORM_ADD = PRE_ADD_STORE, // fused add before layernorm, but not store result
// fused add before layernorm (postnorm), but not store result
PRE_ADD = 2, PRE_ADD = 2,
POST_NORM_ADD = PRE_ADD,
}; };
// clang-format off // clang-format off
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment