This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation.
This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation. We now support
# implementatino and feature support
standard layernorm2d forward is supported. We use welfold algorithm to update mean/variance block by block. For `N <=4096` case we can compute mean/var/normalize within one loop, we call it `one-pass`. For large N case, since the register usage is quite big to compute mean/var while keep inside register for later normalization, we first compuet mean/var block-by-block, then load input another time to compute the normalization. We call it `two-pass`.
## mean/variance save
In training case the mean/variance need to store out (TBD, not supported yet)
## prenorm/postnorm

since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite useful in LLM blocks, this example also support it. Note that prenorm/postnorm always need to fuse a `shortcut` before the actual layernorm computation, the only difference is weather store the added element to global, where prenorm need store out. You can use `-fadd=1` to test prenorm(pre-add+store), or `-fadd=2` to test postnorm(pre-add)
## build
## build
```
```
...
@@ -15,8 +27,15 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
...
@@ -15,8 +27,15 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
```
```
args:
args:
-m m dimension (default:3328)
-m m dimension (default:3328)
-n m dimension (default:4096)
-n n dimension (default:4096)
-stride stride per row, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5)
-e epsilon (default:1e-5)
-save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0)
-v cpu validation or not (default:1)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
-kname print kernel name or not (default:1)
-prec_i input precision (default:fp16)
-prec_o output precision, set auto will be the same as input (default:auto)