Commit 5cfd751b authored by carlushuang's avatar carlushuang

refactor layernorm2d pipeline and add block-per-block utility

parent 68e67701
@@ -2,184 +2,87 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>

// --- before ---
#include "layernorm2d_fwd.hpp"

template <typename DataType,
          ck_tile::index_t kNRepeat,
          ck_tile::index_t kMThreadPerBlock,
          ck_tile::index_t kNThreadPerBlock,
          ck_tile::index_t kVectorAccessSize,
          bool kPadN,
          bool kTwoPass = false>
using trait_ = layernorm2d_fwd_traits_<DataType,
                                       kNRepeat,
                                       kMThreadPerBlock,
                                       kNThreadPerBlock,
                                       kVectorAccessSize,
                                       kPadN,
                                       false,
                                       kTwoPass>;

// --- after ---
#include "layernorm2d_fwd_instance_common.hpp"

template <typename data_type>
float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
                           layernorm2d_fwd_args a,
                           const ck_tile::stream_config& s)
{
#if 1
    float r = -1;
    // clang-format off
    //                                          rm rn tm  tn vn   pd    mv     2p
    if(a.n <= 64) {
        r = layernorm2d_fwd_<trait_<data_type,  1, 1, 4, 64, 1, true, false, false>>(s, a);
    }
    else if(a.n <= 128) {
        if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 1, 4, 64, 2, true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type,  1, 2, 4, 64, 1, true, false, false>>(s, a);
    }
    else if(a.n <= 256) {
        if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 1, 4, 64, 4, true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 2, 4, 64, 2, true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type,  1, 4, 4, 64, 1, true, false, false>>(s, a);
    }
    else if(a.n <= 512) {
        if (a.n % 8 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 1, 4, 64, 8, true, false, false>>(s, a);
        else if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 2, 4, 64, 4, true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 4, 4, 64, 2, true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type,  1, 8, 4, 64, 1, true, false, false>>(s, a);
    }
    else if(a.n <= 768) {
        if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 3, 4, 64, 4, true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 6, 4, 64, 2, true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type,  1,12, 4, 64, 1, true, false, false>>(s, a);
    }
    else if(a.n <= 1024) {
        if (a.n % 4 == 0)
            // r = layernorm2d_fwd_<trait_<data_type,  1, 4, 4, 64, 4, true, false, false>>(s, a);
            r = layernorm2d_fwd_<trait_<data_type,  1, 1, 1, 256, 4, true, false, false>>(s, a);
        else if (a.n % 8 == 0) // note: unreachable as written, n % 8 == 0 implies n % 4 == 0 above
            r = layernorm2d_fwd_<trait_<data_type,  1, 2, 4, 64, 8, true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type,  1, 8, 4, 64, 2, true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type,  1, 16, 4, 64, 1, true, false, false>>(s, a);
    }
    return r;
#else
    return layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 4, true, false, false>>(s, a);
#endif
    // clang-format on
}
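// The single-letter columns in the table above correspond to the trait_ template
// parameters, in order: rm = Repeat_M, rn = Repeat_N, tm = ThreadPerBlock_M,
// tn = ThreadPerBlock_N, vn = Vector_N, pd = kPadN, mv = kSaveMeanInvStd,
// 2p = kTwoPass. For example, trait_<data_type, 1, 1, 4, 64, 8, true, false, false>
// is a one-pass kernel in which a 4x64-thread block covers 4 rows and
// 64 * 8 = 512 columns per row using 8-wide vector accesses.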
// --- before ---
float layernorm2d_fwd(layernorm2d_fwd_traits t,
                      layernorm2d_fwd_args a,
                      const ck_tile::stream_config& s)
{
    float r = -1;
    if(t.data_type.compare("fp16") == 0)
    {
        if(a.N % 4 == 0)
{
if(a.N <= 128)
{
return a.N == 128
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 4, true>>(s, a);
}
else if(a.N <= 256)
{
return a.N == 256
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 512)
{
return a.N == 512
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 1024)
{
return a.N == 1024
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 2048)
{
return a.N == 2048
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, true>>(s, a);
}
else
{
return a.N % 2048 == 0
           ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, false, true>>(s, a)
           : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 4, true, true>>(s, a);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 256)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 512)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 1024)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 2048)
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 4, 64, 2, true>>(s, a);
}
else
{
return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 4, 64, 2, true, true>>(s, a);
}
}
else
{
return a.N <= 2048
? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 4, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 4, 64, 1, true, true>>(s, a);
}
    }
    else if(t.data_type.compare("bf16") == 0)
    {
        if(a.N % 4 == 0)
{
if(a.N <= 128)
{
return a.N == 128
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 4, true>>(s, a);
}
else if(a.N <= 256)
{
return a.N == 256
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 512)
{
return a.N == 512
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 1024)
{
return a.N == 1024
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 2048)
{
return a.N == 2048
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, true>>(s, a);
}
else
{
return a.N % 2048 == 0
           ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, false, true>>(s, a)
           : layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 4, true, true>>(s, a);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 256)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 512)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 1024)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 2048)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 4, 64, 2, true>>(s, a);
}
else
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 4, 64, 2, true, true>>(s, a);
}
}
else
{
return a.N <= 2048
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 4, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 4, 64, 1, true, true>>(s, a);
}
    }

// --- after: the per-size dispatch above is replaced by the shared layernorm2d_fwd_b16_ helper ---
float layernorm2d_fwd(layernorm2d_fwd_traits t,
                      layernorm2d_fwd_args a,
                      const ck_tile::stream_config& s)
{
    float r = -1;
    if(t.data_type.compare("fp16") == 0)
    {
        return layernorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
    }
    else if(t.data_type.compare("bf16") == 0)
    {
        return layernorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
    }

    if(r < 0)
        throw std::runtime_error("No supported instance found!");
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
template <ck_tile::index_t kNRepeat,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
          ck_tile::index_t kVectorAccessSize,
          bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
                                  kNRepeat,
                                  kMThreadPerBlock,
                                  kNThreadPerBlock,
                                  kVectorAccessSize,
false,
false,
kTwoPass>;
// Disable all 8-element vector read/write instances due to a compiler performance issue
// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
template <ck_tile::index_t kNRepeat,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t kVectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
kNRepeat,
kMThreadPerBlock,
kNThreadPerBlock,
kVectorAccessSize,
true,
false,
kTwoPass>;
// Disable all 8-element vector read/write instances due to a compiler performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 4, 64, 2, true>>(const S&, A);
template float layernorm2d_fwd_<t<32, 4, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 4, 64, 1, true>>(const S&, A);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
template <ck_tile::index_t kNRepeat,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t kVectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::fp16_t,
kNRepeat,
kMThreadPerBlock,
kNThreadPerBlock,
kVectorAccessSize,
false,
false,
kTwoPass>;
// Disable all 8-element fp16 vector read/write instances due to a compiler performance issue
// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
template <ck_tile::index_t kNRepeat,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t kVectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::fp16_t,
kNRepeat,
kMThreadPerBlock,
kNThreadPerBlock,
kVectorAccessSize,
true,
false,
kTwoPass>;
// Disable all 8-element fp16 vector read/write instances due to a compiler performance issue
// template float layernorm2d_fwd_<t<1, 4, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 4, 64, 8, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 4, true>>(const S&, A);
template float layernorm2d_fwd_<t<1, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 4, 64, 2, true>>(const S&, A);
template float layernorm2d_fwd_<t<32, 4, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 4, 64, 1, true>>(const S&, A);
@@ -7,36 +7,131 @@
#pragma once
#ifndef _MAX2
#define _MAX2(a, b) ((a) > (b) ? (a) : (b))
#endif
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct layernorm2d_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
};
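// Worked example of the derived shapes (assuming warpSize == 64, as on AMD GPUs):
//
//   layernorm2d_fwd_traits_<bf16_t, 1, 1, 4, 64, 8, ...>
//     4x64 = 256 threads = 4 warps; ThreadPerBlock_N (64) <= warpSize -> warp-per-row:
//     BlockWarps = <4, 1>, Block tile = <4, 1*64*8 = 512>, Warp tile = <1, 512>
//
//   layernorm2d_fwd_traits_<bf16_t, 1, 1, 1, 256, 4, ...>
//     1x256 threads = 4 warps; ThreadPerBlock_N (256) > warpSize -> warps cooperate along N:
//     BlockWarps = <1, 4>, Block tile = <1, 1*256*4 = 1024>, Warp tile = <1, 256>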
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveMeanInvStd_,
kTwoPass_>;
#include <iostream>
template <typename Traits_>
float layernorm2d_fwd_(const S& s, A a)
{
    using DataType = typename Traits_::DataType;

    // --- before ---
    using PipelineProblem =
        ck_tile::BlockLayernorm2dFwdProblem<typename LayerNormTypeConfig<DataType>::XDataType,
                                            typename LayerNormTypeConfig<DataType>::GammaDataType,
                                            typename LayerNormTypeConfig<DataType>::BetaDataType,
                                            typename LayerNormTypeConfig<DataType>::ComputeDataType,
                                            typename LayerNormTypeConfig<DataType>::YDataType,
                                            typename LayerNormTypeConfig<DataType>::MeanDataType,
                                            typename LayerNormTypeConfig<DataType>::InvStdDataType,
                                            typename Traits_::Shape,
                                            Traits_::kPadN,
                                            Traits_::kSaveMeanInvStd,
                                            Traits_::kTwoPass>;

    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

    const dim3 grids = Kernel::GridSize(a.M);
    constexpr dim3 blocks = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = Kernel::MakeKargs(
        a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);

    // --- after ---
    using PipelineProblem = ck_tile::Layernorm2dFwdWarpPerRowProblem<
        typename LayerNormTypeConfig<DataType>::XDataType,
        typename LayerNormTypeConfig<DataType>::GammaDataType,
        typename LayerNormTypeConfig<DataType>::BetaDataType,
        typename LayerNormTypeConfig<DataType>::ComputeDataType,
        typename LayerNormTypeConfig<DataType>::YDataType,
        typename LayerNormTypeConfig<DataType>::MeanDataType,
        typename LayerNormTypeConfig<DataType>::InvStdDataType,
        typename Traits_::Shape,
        Traits_::kPadN,
        Traits_::kSaveMeanInvStd,
        Traits_::kTwoPass>;

    using Pipeline = ck_tile::Layernorm2dFwdWarpPerRowPipeline<PipelineProblem>;
    using Kernel   = ck_tile::Layernorm2dFwd<Pipeline>;

    const dim3 grids                       = Kernel::GridSize(a);
    constexpr dim3 blocks                  = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = Kernel::MakeKargs(a);

    if(s.log_level_ > 0)
        std::cout << ", " << Kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
#undef _MAX2
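// A minimal sketch (illustrative only, not part of the generated instance set) of
// calling one concrete instance directly instead of going through the
// layernorm2d_fwd() dispatcher; the trait combination below mirrors one that
// this commit instantiates:
//
//   inline float layernorm2d_fwd_example(const S& s, A a)
//   {
//       // fp16, Repeat_M/N = 1/1, 4x64 threads, 4-wide vector I/O, padded N
//       return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true, false, false>>(s, a);
//   }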
@@ -23,9 +23,12 @@ auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3328", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("stride", "-1", "stride per row, if -1 then equal to n")
        .insert("e", "1e-5", "epsilon")
        .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        .insert("prec", "fp16", "precision")
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter");
@@ -34,18 +37,23 @@ auto create_args(int argc, char* argv[])
    return std::make_tuple(result, arg_parser);
}
template <typename DataType, bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t m      = arg_parser.get_int("m");
    ck_tile::index_t n      = arg_parser.get_int("n");
    ck_tile::index_t stride = arg_parser.get_int("stride");
    if(stride < 0)
        stride = n;
    float epsilon         = arg_parser.get_float("e");
    std::string data_type = arg_parser.get_str("prec");
    int kname             = arg_parser.get_int("kname");
    int do_validation     = arg_parser.get_int("v");
    int warmup            = arg_parser.get_int("warmup");
    int repeat            = arg_parser.get_int("repeat");
    assert(stride >= n);
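    // Row-major layout with an optional row pitch: element (i, j) of x/y lives at
    // offset i * stride + j, so stride == n means densely packed rows, while
    // stride > n leaves (stride - n) untouched padding elements at the end of each row.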
    using TypeConfig = LayerNormTypeConfig<DataType>;

    using XDataType = typename TypeConfig::XDataType;
@@ -53,21 +61,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
    using GammaDataType = typename TypeConfig::GammaDataType;
    using BetaDataType  = typename TypeConfig::BetaDataType;

    using MeanDataType =
        std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
    using InvStdDataType =
        std::conditional_t<SaveMeanVar, typename TypeConfig::InvStdDataType, ck_tile::null_type>;

    using ComputeDataType = typename TypeConfig::ComputeDataType;

    // host verify
    ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
    ck_tile::HostTensor<GammaDataType> gamma_host({n});
    ck_tile::HostTensor<BetaDataType> beta_host({n});

    ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
    ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});

    ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
    ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});

    ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
    ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
@@ -82,7 +92,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
    gamma_buf.ToDevice(gamma_host.data());
    beta_buf.ToDevice(beta_host.data());

    std::cout << "[" << data_type << "]"
              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;

    layernorm2d_fwd_traits traits{data_type, SaveMeanVar};
    layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                              gamma_buf.GetDeviceBuffer(),
@@ -91,19 +104,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              nullptr,
                              nullptr,
                              epsilon,
                              m,
                              n,
                              stride};

    float ave_time = layernorm2d_fwd(
        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

    std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
                           sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;

    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
    bool pass = true;
@@ -122,8 +134,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
        y_buf.FromDevice(y_host_dev.data());

        auto [rtol, atol] = get_elimit<DataType>();

        if(stride == n)
        {
            pass = ck_tile::check_err(
                y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
        }
        else
        {
            for(int i_r = 0; i_r < m; i_r++)
            {
                std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
                                                      y_host_dev.begin() + i_r * stride + n);
                std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
                                                      y_host_ref.begin() + i_r * stride + n);
                pass &= ck_tile::check_err(y_host_dev_row,
                                           y_host_ref_row,
                                           std::string("OUT[") + std::to_string(i_r) +
                                               std::string("] Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
        }
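        // Rows are validated slice-by-slice above because a whole-tensor check_err
        // would also compare the (stride - n) padding lanes, which the kernel
        // never writes.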
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
} }
@@ -138,13 +169,22 @@ int main(int argc, char* argv[])
        return -1;

    const std::string data_type = arg_parser.get_str("prec");
    int save_mv                 = arg_parser.get_int("save_mv");

    if(data_type == "fp16" && save_mv)
    {
        return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "fp16" && !save_mv)
    {
        return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16" && save_mv)
    {
        return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16" && !save_mv)
    {
        return run<ck_tile::bf16_t, false>(arg_parser) ? 0 : -2;
    }

    return -3;
@@ -36,58 +36,15 @@ struct LayerNormTypeConfig<ck_tile::bf16_t>
};

// runtime args

// --- before ---
struct layernorm2d_fwd_args
{
    const void* p_x;
    const void* p_gamma;
    const void* p_beta;
    void* p_y;
    void* p_mean;
    void* p_invStd;
    float epsilon;
    ck_tile::index_t M;
    ck_tile::index_t N;
};

// --- after ---
struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
{
};
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename DataType_,
ck_tile::index_t kNRepeat,
ck_tile::index_t kMThreadPerBlock,
ck_tile::index_t kNThreadPerBlock,
ck_tile::index_t kVectorAccessSize,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct layernorm2d_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t MRepeat = 1;
static_assert(kNThreadPerBlock <= 64, "We only support intra-wave reduction");
static constexpr ck_tile::index_t kNWarpPerBlock = 1;
static constexpr ck_tile::index_t kMWarpPerBlock =
kMThreadPerBlock * kNThreadPerBlock / warpSize;
using thread_tile = ck_tile::sequence<MRepeat, kNRepeat, kVectorAccessSize>;
using warp_tile = ck_tile::sequence<MRepeat * warpSize / kNThreadPerBlock,
kNRepeat * kNThreadPerBlock * kVectorAccessSize>;
using block_tile = ck_tile::sequence<kMWarpPerBlock * MRepeat * warpSize / kNThreadPerBlock,
kNRepeat * kNThreadPerBlock * kVectorAccessSize>;
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
// This is the public API, will be generated by script
struct layernorm2d_fwd_traits
{
    std::string data_type;
    bool save_mean_var;
};

float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
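// Minimal usage sketch of the public API (illustrative only; device buffer
// allocation elided, and the field layout of ck_tile::Layernorm2dFwdHostArgs is
// assumed to follow the aggregate order used by the example app above):
//
//   layernorm2d_fwd_traits t{"fp16", /*save_mean_var=*/false};
//   layernorm2d_fwd_args a{p_x, p_gamma, p_beta, p_y,
//                          nullptr, nullptr, // p_mean / p_invStd, unused when !save_mean_var
//                          1e-5f,            // epsilon
//                          m, n, stride};
//   float avg_ms = layernorm2d_fwd(t, a, ck_tile::stream_config{nullptr, true});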
./bin/tile_example_layernorm2d_fwd -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
# run from top of ck folder
EXE=build/bin/tile_example_layernorm2d_fwd
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000