Commit abe875d6 authored by rocking's avatar rocking
Browse files

unify layernorm api

parent 02236580
...@@ -5,7 +5,7 @@ message("adding example ${EXAMPLE_LAYERNORM2D_FWD}") ...@@ -5,7 +5,7 @@ message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp) file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp) add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_fp16.cpp layernorm2d_fwd_fp32.cpp ${INSTANCE_SRCS}) target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_api.cpp ${INSTANCE_SRCS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS) set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
......
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
#include "layernorm2d_fwd.hpp" #include "layernorm2d_fwd.hpp"
#include <cstring> #include <cstring>
extern float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream);
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
...@@ -95,18 +92,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -95,18 +92,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
M, M,
N}; N};
float ave_time = .0; float ave_time =
layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
if constexpr(std::is_same<DataType, ck_tile::fp16_t>::value)
{
ave_time =
layernorm2d_fwd_fp16(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
}
else if constexpr(std::is_same<DataType, float>::value)
{
ave_time =
layernorm2d_fwd_fp32(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
}
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
float layernorm2d_fwd(layernorm2d_fwd_traits t,
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
float r = -1;
if(t.data_type.compare("fp16") == 0)
{
// Disable all vector 8fp16 read/write instances as it has performance issue regarding
// compiler
#if 0
if(a.N % 8 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(a, s);
}
else
{
return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(a, s);
}
}
else if(a.N % 4 == 0)
#endif
if(a.N % 4 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(a, s);
}
else if(a.N <= 2048)
{
return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(a, s);
}
else
{
return a.N % 2048 == 0 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(a, s)
: run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(a, s);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(a, s);
}
else if(a.N <= 2048)
{
return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(a, s);
}
else
{
return a.N % 2048 == 0
? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(a, s)
: run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(a, s);
}
}
}
else if(t.data_type.compare("fp32") == 0)
{
if(a.N % 4 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<float, 1, 32, 4, false>(a, s)
: run_layernorm<float, 1, 32, 4, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<float, 1, 64, 4, false>(a, s)
: run_layernorm<float, 1, 64, 4, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<float, 2, 64, 4, false>(a, s)
: run_layernorm<float, 2, 64, 4, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<float, 4, 64, 4, false>(a, s)
: run_layernorm<float, 4, 64, 4, true>(a, s);
}
else if(a.N <= 2048)
{
return a.N == 2048 ? run_layernorm<float, 8, 64, 4, false>(a, s)
: run_layernorm<float, 8, 64, 4, true>(a, s);
}
else
{
return a.N % 2048 == 0 ? run_layernorm<float, 8, 64, 4, false, true>(a, s)
: run_layernorm<float, 8, 64, 4, true, true>(a, s);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<float, 1, 64, 2, false>(a, s)
: run_layernorm<float, 1, 64, 2, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<float, 2, 64, 2, false>(a, s)
: run_layernorm<float, 2, 64, 2, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<float, 4, 64, 2, false>(a, s)
: run_layernorm<float, 4, 64, 2, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<float, 8, 64, 2, false>(a, s)
: run_layernorm<float, 8, 64, 2, true>(a, s);
}
else if(a.N <= 2048)
{
return a.N == 2048 ? run_layernorm<float, 16, 64, 2, false>(a, s)
: run_layernorm<float, 16, 64, 2, true>(a, s);
}
else
{
return a.N % 2048 == 0 ? run_layernorm<float, 16, 64, 2, false, true>(a, s)
: run_layernorm<float, 16, 64, 2, true, true>(a, s);
}
}
}
return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
#if 0
if(param.N % 8 == 0)
{
if(param.N <= 128)
{
return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(param, stream);
}
else if(param.N <= 256)
{
return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(param, stream);
}
else if(param.N <= 512)
{
return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(param, stream);
}
else if(param.N <= 1024)
{
return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(param, stream);
}
else
{
return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(param, stream);
}
}
else if(param.N % 4 == 0)
#endif
if(param.N % 4 == 0)
{
if(param.N <= 128)
{
return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(param, stream);
}
else if(param.N <= 256)
{
return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(param, stream);
}
else if(param.N <= 512)
{
return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(param, stream);
}
else if(param.N <= 1024)
{
return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(param, stream);
}
else if(param.N <= 2048)
{
return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(param, stream);
}
else
{
return param.N % 2048 == 0
? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(param, stream)
: run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(param, stream);
}
}
else if(param.N % 2 == 0)
{
if(param.N <= 128)
{
return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(param, stream);
}
else if(param.N <= 256)
{
return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(param, stream);
}
else if(param.N <= 512)
{
return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(param, stream);
}
else if(param.N <= 1024)
{
return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(param, stream);
}
else if(param.N <= 2048)
{
return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(param, stream);
}
else
{
return param.N % 2048 == 0
? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(param, stream)
: run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(param, stream);
}
}
else
{
throw std::runtime_error("Sequence length sizes not supported!");
}
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
extern template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
if(param.N % 4 == 0)
{
if(param.N <= 128)
{
return param.N == 128 ? run_layernorm<float, 1, 32, 4, false>(param, stream)
: run_layernorm<float, 1, 32, 4, true>(param, stream);
}
else if(param.N <= 256)
{
return param.N == 256 ? run_layernorm<float, 1, 64, 4, false>(param, stream)
: run_layernorm<float, 1, 64, 4, true>(param, stream);
}
else if(param.N <= 512)
{
return param.N == 512 ? run_layernorm<float, 2, 64, 4, false>(param, stream)
: run_layernorm<float, 2, 64, 4, true>(param, stream);
}
else if(param.N <= 1024)
{
return param.N == 1024 ? run_layernorm<float, 4, 64, 4, false>(param, stream)
: run_layernorm<float, 4, 64, 4, true>(param, stream);
}
else if(param.N <= 2048)
{
return param.N == 2048 ? run_layernorm<float, 8, 64, 4, false>(param, stream)
: run_layernorm<float, 8, 64, 4, true>(param, stream);
}
else
{
return param.N % 2048 == 0 ? run_layernorm<float, 8, 64, 4, false, true>(param, stream)
: run_layernorm<float, 8, 64, 4, true, true>(param, stream);
}
}
else if(param.N % 2 == 0)
{
if(param.N <= 128)
{
return param.N == 128 ? run_layernorm<float, 1, 64, 2, false>(param, stream)
: run_layernorm<float, 1, 64, 2, true>(param, stream);
}
else if(param.N <= 256)
{
return param.N == 256 ? run_layernorm<float, 2, 64, 2, false>(param, stream)
: run_layernorm<float, 2, 64, 2, true>(param, stream);
}
else if(param.N <= 512)
{
return param.N == 512 ? run_layernorm<float, 4, 64, 2, false>(param, stream)
: run_layernorm<float, 4, 64, 2, true>(param, stream);
}
else if(param.N <= 1024)
{
return param.N == 1024 ? run_layernorm<float, 8, 64, 2, false>(param, stream)
: run_layernorm<float, 8, 64, 2, true>(param, stream);
}
else if(param.N <= 2048)
{
return param.N == 2048 ? run_layernorm<float, 16, 64, 2, false>(param, stream)
: run_layernorm<float, 16, 64, 2, true>(param, stream);
}
else
{
return param.N % 2048 == 0 ? run_layernorm<float, 16, 64, 2, false, true>(param, stream)
: run_layernorm<float, 16, 64, 2, true, true>(param, stream);
}
}
else
{
throw std::runtime_error("Sequence length sizes not supported!");
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment