Unverified Commit c3a4800c authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

[CK_TILE] layernorm support fused-quant/fused-add (#1604)

* add prenorm/postnorm support, refactor using generate.py

* update README

* update README

* fix format

* update some description and fix format

* update format

* format

* use non-raw for loading

* format and update n4096

* dynamic-quant ready

* update readme

* support fused dynamic-quant

* update fused-quant, with smooth

* update README

* update args

* update some based on comment
parent 9a8a5213
set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all")
set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS})
endif()
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
)
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${INSTANCE_SRCS})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
......
# Layernorm2D forward
This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation.
This folder contains example for Layernorm2D forward using `ck_tile` tile-programming implementation.
# Implementation and feature support
## welford online algorithm
We use welfold algorithm to update `mean`/`variance` block by block. For `N <=4096` case we can compute `mean`/`var`/`normalization` within one loop, we call it `one-pass`. For large N case, it is hard to keep `mean`/`var` inside register/LDS and then computation `normalization`, so we need to load input twice, first time to compute `mean`/`var` block-by-block, then load input another time to compute the `normalization`. We call it `two-pass`.
## mean/variance save
In training case the mean/variance need to store out (TBD, not supported yet)
## prenorm/postnorm
![](misc/pnorm.png)
since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example boosts this feature by kernel fusion. Note that `prenorm`/`postnorm` always need to do elementwise-add a `shortcut` before the actual layernorm computation, and optionally store out the result to global. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without store out (not codegen by default).
## smooth-quant/dynamic-quant
we support smooth/dynamic quantization for `int8` output, by setting `-fquant=1` and `-prec_o=int8`. In this case the output will doing a rowwise dynamic quantization like below. Note that smooth-quant require input a `(1*N)` size per-channel scale(in fp32 in our example, though this is customizable), then elememt-wise multiply the tensor for each row, then compute the rowwise dynamic quant. if set `-fquant=2` will have the input per-channel scale stage, only the dynamic quant. This case is supported in our kernel but by default not generated (TBD: add some filter in generate.py support on-demand codegen)
![](misc/dquant.png)
```
# assume output int8, hidden_states is [m, n] shape and in fp16/bf16
# [m, 1]
per_token_amax, _ = torch.max(
input=torch.abs(hidden_states),
dim=-1,
keepdim=True
)
per_token_scale = per_token_amax.to(dtype=torch.float32) / 127.0
# quant hidden_states
hidden_states = (hidden_states / per_token_scale).to(dtype=torch.int8)
return hidden_states, per_token_scale
# hidden_states now is int8 will feed to next layer as intput
# per_token_scale will be used as dequant factor later layer
```
## build
```
......@@ -15,8 +51,35 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
```
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-n n dimension (default:4096)
-stride stride per row, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5)
-save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
-kname print kernel name or not (default:1)
-prec_i input precision (default:fp16)
-prec_o output precision, set auto will be the same as input (default:auto)
-prec_sx output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
-prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
```
## limitations
Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, N>8192 case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet.
```
# some case
# standard fp16 layernorm 2d, m=10. n=1024
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024
# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant, output in int8
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1
# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1
```
\ No newline at end of file
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveMeanInvStd_,
kTwoPass_>;
template <typename data_type>
float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
#if 1
float r = -1;
// clang-format off
// rm rn tm tn vn pd mv 2p
if(a.n <= 64) {
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 128) {
if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 256) {
if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 512) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 768) {
if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 1024) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 1536) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 2048) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 3072) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, false>>(s, a);
}
else if(a.n <= 4096) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, false>>(s, a);
}
else if(a.n > 4096) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, true>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, true>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, true>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
}
return r;
#else
return layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 4, true, false, false>>(s, a);
#endif
// clang-format on
}
float layernorm2d_fwd(layernorm2d_fwd_traits t,
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
float r = -1;
if(t.data_type.compare("fp16") == 0)
{
return layernorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return layernorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
if(r < 0)
throw std::runtime_error("Without supported instances!");
return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
#if 0
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
#if 0
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment