Commit 4b59b5c9 authored by carlushuang

add prenorm/postnorm support, refactor using generate.py

parent 4d5248e2
set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all")
set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS})
endif()
# generate the list of kernels at configure stage, but do not actually emit the files yet
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to generate kernels via Python: ${ret}")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
)
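# for reference, the two phases above map onto the two generate.py modes
# (paths illustrative, not taken from this commit):
#   <Python3> generate.py --api fwd --working_path <build_dir> --list_blobs   # configure time
#   <Python3> generate.py --api fwd --working_path <build_dir> --gen_blobs    # build time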
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
# not using add_example_executable() to add this target, since we don't want it
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}") message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import sys
from typing import List, Optional, Any
import functools
import itertools
import copy
from dataclasses import dataclass
def get_if_str(idx, total, last_else = True):
    if idx == 0:
        return 'if'
    elif idx < total - 1:
        return 'else if'
    else:
        if last_else:
            return 'else'
        else:
            return 'else if'
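# sanity check (illustrative): a 3-branch chain yields 'if', 'else if', 'else';
# with last_else=False the final branch stays 'else if' so it can still carry a condition
#   [get_if_str(i, 3) for i in range(3)]         # -> ['if', 'else if', 'else']
#   [get_if_str(i, 3, False) for i in range(3)]  # -> ['if', 'else if', 'else if']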
FUSED_ADD_ENUM_STR_MAP = [
'no',
'pras', # pre-norm
'pra' ] # post-norm
FUSED_SWEEP_STR_MAP = [
    'no',
    'renorm',
    'dequant' ]
DATA_TYPE_MAP = {'fp16' : 'ck_tile::fp16_t',
'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t'}
def BOOL_MAP(b_) -> str:
if b_:
return 'true'
else:
return 'false'
class layernorm_fwd_codegen:
API_TRAITS_DEFINE = """
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename XDataType_,
typename YDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedSweep_ = 0>
struct layernorm2d_fwd_traits_
{
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedSweep = kFusedSweep_;
};
template <typename XDataType_,
typename YDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedSweep_>
using traits_ = layernorm2d_fwd_traits_<XDataType_,
YDataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveMeanInvStd_,
kTwoPass_,
kFusedAdd_,
kFusedSweep_>;
"""
API_COMMON_HEADER = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
#include <ck_tile/ops/epilogue.hpp>
#include <iostream>
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
{F_traits_define}
template <typename Traits_>
float layernorm2d_fwd_(const S& s, A a)
{{
using XDataType = typename Traits_::XDataType;
using YDataType = typename Traits_::YDataType;
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType>::ComputeDataType;
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd,
Traits_::kTwoPass,
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedSweepEnum>(Traits_::kFusedSweep)>;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
typename LayerNormTypeConfig<XDataType, YDataType>::XDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::GammaDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::BetaDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::ComputeDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::YDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::MeanDataType,
typename LayerNormTypeConfig<XDataType, YDataType>::InvStdDataType,
typename Traits_::Shape,
PipelineTraits>;
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
using Epilogue = Default2DEpilogue;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""
API_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
{F_traits_define}
float layernorm2d_fwd(layernorm2d_fwd_traits t,
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{{
float r = -1;
{F_dispatch}
return r;
}}
"""
API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
{F_per_n_case}
}}
"""
API_PER_N_CASE=""" {F_if} {F_N_COND} {{
{F_inner_dispatch}
}}
"""
API_INNER_CASE=""" {F_if} {F_VEC_COND}
r={F_instance_func}(s, a);
"""
INSTANCE_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_api_common.hpp"
// clang-format off
// prec_i prec_o rm rn tm tn vn pd mv 2p add sweep
{F_instance_def}
// clang-format on
"""
def __init__(self, working_path, kernel_filter):
self.working_path = working_path
self.kernel_filter = kernel_filter
class k_fused_add_enum(IntEnum):
F_NO_ADD = 0
F_PRE_ADD = 1
F_PRE_ADD_STORE_RESIDUAL = 2
class k_fused_sweep_enum(IntEnum):
F_NO_SWEEP = 0
F_RENORM = 1
F_DYNAMIC_QUANT = 2
@dataclass
class k_traits:
F_kPadN : bool
F_kSaveMeanInvStd : bool
F_kTwoPass : bool
        F_kFusedAdd : Any  # layernorm_fwd_codegen.k_fused_add_enum
        F_kFusedSweep : Any  # layernorm_fwd_codegen.k_fused_sweep_enum
@dataclass
class k_shape:
F_BlockTile : List[int]
F_WarpPerBlock : List[int]
F_WarpTile : List[int]
F_Vector_ : List[int]
@property
def F_BlockSize(self) -> int:
return functools.reduce(lambda a, b: a*b, self.F_WarpTile)
@dataclass
class k_problem:
F_XDataType : str
F_GammaDataType : str
F_BetaDataType : str
F_ComputeDataType : str
F_YDataType : str
F_MeanDataType : str
F_InvStdDataType : str
F_BlockShape : str
F_Traits : Any #k_traits
@dataclass
class k_pipeline_one_pass:
F_Problem : Any #k_problem
@dataclass
class k_pipeline_two_pass:
F_Problem : Any #k_problem
@dataclass
class default_2d_epilogue_problem:
F_AccDataType : str
F_ODataType : str
F_kPadM : bool
F_kPadN : bool
@dataclass
class default_2d_epilogue:
F_problem : Any
@dataclass
class k_kernel:
F_pipeline : Any
F_epilogue : Any
@dataclass
class h_traits:
F_XDataType : str
F_YDataType : str
F_Repeat_M : int
F_Repeat_N : int
F_ThreadPerBlock_M : int
F_ThreadPerBlock_N : int
F_Vector_N : int
F_kPadN : bool
F_kSaveMeanInvStd_ : bool
F_kTwoPass_ : bool
F_kFusedAdd : int
F_kFusedSweep : int
@property
def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedSweep:4}'
return t_
# string when calling this kernel
@property
def call_name(self) -> str:
return f'layernorm2d_fwd_<traits_<{self.trait_name}>>'
        # string when defining this kernel
@property
def def_name(self) -> str:
return f'template float layernorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'
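        # def_name emits one explicit instantiation per kernel, e.g. (illustrative):
        #   template float layernorm2d_fwd_<traits_<ck_tile::fp16_t, ck_tile::fp16_t,  1,  2,  4,   64,  2, true , false, false,    0,    0>>(const S&, A);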
    # this class holds the kernels that share one generated source file
@dataclass
class h_instance:
F_DataTypePair : str
F_N : str
F_add : int
F_sweep : int
instance_list : List[Any] # List[h_traits]
@property
def name(self) -> str:
prec_i, prec_o = self.F_DataTypePair.split(',')
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}'
if self.F_add != 0:
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
if self.F_sweep != 0:
                nnn = nnn + '_' + FUSED_SWEEP_STR_MAP[self.F_sweep]
return nnn
@property
def instance_name(self) ->str:
return self.name
@property
def content(self) ->str:
instance_defs = ''
for ins in self.instance_list:
instance_defs += ins.def_name + '\n'
return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)
@property
def name_api(self) -> str:
return 'layernorm2d_fwd_api'
@property
def name_common_header(self) -> str:
return 'layernorm2d_fwd_api_common'
@property
def content_api(self) -> str:
        # 1. bucket blobs by dtype, then by N
t_dtype_dict = dict()
blobs = self.get_blobs()
for blob in blobs:
if blob.F_DataTypePair not in t_dtype_dict:
t_dtype_dict[blob.F_DataTypePair] = {}
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
d_str = ''
for i_d, dtype_ in enumerate(t_dtype_dict):
blob_per_t = t_dtype_dict[dtype_]
n_str = ''
for i_n, n_ in enumerate(blob_per_t):
blob_per_n = blob_per_t[n_]
inner_str = ""
for i_b, b_ in enumerate(blob_per_n):
                    # generate a dispatch branch for every kernel instance in this bucket
for i_ins, ins in enumerate(b_.instance_list):
idx_in_n = i_b * len(b_.instance_list) + i_ins
len_in_n = len(blob_per_n) * len(b_.instance_list)
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && (t.fused_sweep == {f_fused_sweep}))'.format(
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
f_fused_sweep = ins.F_kFusedSweep)
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name)
n_cnd = f'(a.n <= {n_})' if n_ != 'big' else ''
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
prec_i, prec_o = dtype_.split(',')
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
return api_base
@property
def content_common_header(self) -> str:
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
def get_blobs(self):
h_traits = layernorm_fwd_codegen.h_traits
h_instance = layernorm_fwd_codegen.h_instance
        # some predefined supported ranges
        # (prec_i,prec_o); for simplicity this string is used as the dict key
        dtype_list = ['fp16,fp16', 'bf16,bf16']
fused_add_list = [0, 1, 2]
fused_sweep_list = [0]
# rm rn tm tn vn pd mv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
'128' : [ h_traits('x', 'y', 1, 1, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
'256' : [ h_traits('x', 'y', 1, 1, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
'512' : [ h_traits('x', 'y', 1, 1, 4, 64, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
'768' : [ h_traits('x', 'y', 1, 3, 4, 64, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 6, 4, 64, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 1, 1, 2, 128, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 2, 128, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 2, 128, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 1, 3, 4, 64, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 3, 2, 128, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 3, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 1, 1, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 1, 3, 1, 128, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 3, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 6, 1, 256, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 1, 2, 1, 256, 8, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1, 256, 4, True, False, False, 0, 0),
h_traits('x', 'y', 1, 2, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
'big' :[ h_traits('x', 'y', 1, 2, 1, 256, 8, True, False, True, 0, 0),
h_traits('x', 'y', 1, 4, 1, 256, 4, True, False, True, 0, 0),
h_traits('x', 'y', 1, 2, 1,1024, 2, True, False, True, 0, 0),
h_traits('x', 'y', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
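        # every row satisfies Repeat_N * ThreadPerBlock_N * Vector_N == N for its bucket
        # (e.g. 512 = 2*64*4); the generated api picks the widest Vector_N dividing a.n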
total_blob = list()
for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key]
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
for dtype, fused_add, fused_sweep in itertools.product(dtype_list, fused_add_list, fused_sweep_list):
prec_i, prec_o = dtype.split(',')
current_hs = list()
for chs_ in hs:
h_ = copy.copy(chs_) # copy the base instance out
h_.F_XDataType = prec_i
h_.F_YDataType = prec_o
h_.F_kFusedAdd = fused_add
h_.F_kFusedSweep = fused_sweep
                current_hs.append(h_)
current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_sweep, current_hs))
return total_blob
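    # illustrative count: 11 N buckets x 2 dtype pairs x 3 fused-add modes x 1 sweep
    # mode -> get_blobs() returns 66 h_instance blobs, i.e. 66 generated .cpp files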
def list_blobs(self) -> None:
w_p = Path(self.working_path)
list_p = w_p / 'layernorm2d_fwd_blobs.txt'
blobs = self.get_blobs()
with list_p.open('a') as list_f:
# api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
# kernel instance file
for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
def gen_blobs(self) -> None:
w_p = Path(self.working_path)
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
blobs = self.get_blobs()
for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content)
def list_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs()
def gen_blobs(args):
api_list = args.api.split(',')
for api in api_list:
if api == 'fwd':
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for CK layernorm kernel",
)
parser.add_argument(
"-a",
"--api",
default='fwd[all]',
required=False,
help="supply API(s) to generate (default: fwd). separated by comma."
)
# the directory for list_blobs/gen_blobs to write files into
parser.add_argument(
"-w",
"--working_path",
default="./",
required=False,
help="the path where all the blobs are going to be generated"
)
# this script has 2 modes
# 1) list_blobs mode: generate a txt file listing all the files that will be generated.
#    this is useful for build systems like cmake, which can construct the source-code
#    dependency by reading the content of this file
# 2) gen_blobs mode: generate the actual kernel instances and the api. frameworks
#    like FA only need to use this mode
parser.add_argument(
"-l",
"--list_blobs",
action='store_true',
help="list all the kernels to a file, "
)
parser.add_argument(
"-g",
"--gen_blobs",
action='store_true',
help="generate all kernels into different tile"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
parser.add_argument(
"-t",
"--traits",
default="all",
required=False,
help="enable/disable some feature. default generate all"
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt."
)
args = parser.parse_args()
# print(f'{args.list_blobs}-{args.gen_blobs}')
    if args.gen_blobs == args.list_blobs:
        print('must specify exactly one of --gen_blobs/--list_blobs')
        sys.exit(1)
p = Path(args.working_path)
if not p.exists():
p.mkdir()
if args.list_blobs:
list_blobs(args)
else:
gen_blobs(args)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveMeanInvStd_,
kTwoPass_>;
template <typename data_type>
float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
// rm rn tm tn vn pd mv 2p
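    // dispatch heuristic (summary): within each row-length bucket pick the widest
    // vector width (vn) that divides a.n; rows longer than 4096 use the two-pass (2p) pipeline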
if(a.n <= 64) {
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 128) {
if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 256) {
if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 512) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 768) {
if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, false, false>>(s, a);
}
else if(a.n <= 1024) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 1536) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 2048) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, false, false>>(s, a);
}
else if(a.n <= 3072) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, false>>(s, a);
}
else if(a.n <= 4096) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, false>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, false>>(s, a);
}
else if(a.n > 4096) {
if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, true>>(s, a);
else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, true>>(s, a);
else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, true>>(s, a);
else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
}
return r;
// clang-format on
}
float layernorm2d_fwd(layernorm2d_fwd_traits t,
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
float r = -1;
if(t.data_type.compare("fp16") == 0)
{
return layernorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return layernorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
    if(r < 0)
        throw std::runtime_error("no supported instance for this data type!");
return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
#if 0
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
#if 0
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on