Commit 50e10656 authored by aska-0096's avatar aska-0096
Browse files

tempsave

parent 888317e6
......@@ -116,15 +116,25 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, false>>(s, a);
}
else if(a.n > 4096) {
else if(a.n <= 8192) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, true, true>>(s, a);
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 256, 8, true, true, false>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, true, true>>(s, a);
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 8, 1, 256, 4, true, true, false>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, true, true>>(s, a);
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 2, true, true, false>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, true>>(s, a);
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 8, 1, 1024, 1, true, true, false>>(s, a);
}
else if(a.n > 8192) {
if (a.n % 8 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 256, 8, true, true, true>>(s, a);
else if (a.n % 4 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 8, 1, 256, 4, true, true, true>>(s, a);
else if (a.n % 2 == 0)
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 2, true, true, true>>(s, a);
else
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 8, 1, 1024, 1, true, true, true>>(s, a);
}
return r;
// clang-format on
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on
......@@ -6,9 +6,9 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on
......@@ -6,9 +6,9 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment