Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a4501f13
Commit
a4501f13
authored
Jan 21, 2025
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/ck_tile_gemm_policy
parents
c6dcf20d
e7dce4d2
Changes
368
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
384 additions
and
467 deletions
+384
-467
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
+0
-22
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
.../10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
+0
-65
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
+280
-54
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
+34
-85
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
+29
-25
example/ck_tile/12_smoothquant/example_smoothquant.cpp
example/ck_tile/12_smoothquant/example_smoothquant.cpp
+41
-33
No files found.
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
2
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
2
,
128
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
2
,
128
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
2
,
128
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
deleted
100644 → 0
View file @
c6dcf20d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <iostream>
#pragma once
using
S
=
ck_tile
::
stream_config
;
using
A
=
rmsnorm2d_fwd_args
;
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
>
using
trait_
=
rmsnorm2d_fwd_traits_
<
DataType_
,
Repeat_M_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
,
kSaveInvRms_
,
kTwoPass_
>
;
template
<
typename
Traits_
>
float
rmsnorm2d_fwd_
(
const
S
&
s
,
A
a
)
{
using
DataType
=
typename
Traits_
::
DataType
;
using
PipelineProblem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
typename
RmsnormTypeConfig
<
DataType
>::
XDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
GammaDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
ComputeDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
YDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
InvRmsDataType
,
typename
Traits_
::
Shape
,
Traits_
::
kPadN
,
Traits_
::
kSaveInvRms
,
Traits_
::
kTwoPass
>
;
using
OnePassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineOnePass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineTwoPass
<
PipelineProblem
>
;
using
Pipeline
=
std
::
conditional_t
<
Traits_
::
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
>
;
const
dim3
grids
=
Kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
auto
kargs
=
Kernel
::
MakeKargs
(
a
);
if
(
s
.
log_level_
>
0
)
std
::
cout
<<
", "
<<
Kernel
::
GetName
()
<<
std
::
flush
;
return
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
}
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
View file @
a4501f13
This diff is collapsed.
Click to expand it.
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
View file @
a4501f13
This diff is collapsed.
Click to expand it.
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
View file @
a4501f13
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
$EXE
-prec
=
$pr_i
-m
=
99
-n
=
13
$EXE
-prec
=
$pr_i
-m
=
17
-n
=
16
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
100
$EXE
-prec
=
$pr_i
-m
=
4
-n
=
128
$EXE
-prec
=
$pr_i
-m
=
80
-n
=
127
$EXE
-prec
=
$pr_i
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
599
$EXE
-prec
=
$pr_i
-m
=
19
-n
=
512
$EXE
-prec
=
$pr_i
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec
=
$pr_i
-m
=
11
-n
=
510
$EXE
-prec
=
$pr_i
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec
=
$pr_i
-m
=
91
-n
=
636
$EXE
-prec
=
$pr_i
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec
=
$pr_i
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec
=
$pr_i
-m
=
31
-n
=
1024
$EXE
-prec
=
$pr_i
-m
=
64
-n
=
1000
-stride
=
1004
$EXE
-prec
=
$pr_i
-m
=
8
-n
=
1501
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
1826
$EXE
-prec
=
$pr_i
-m
=
5
-n
=
2040
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
2734
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
3182
$EXE
-prec
=
$pr_i
-m
=
9
-n
=
4096
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
8192
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
10547
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
17134
for
fadd
in
"0"
"1"
;
do
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
17
-n
=
16
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
100
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
4
-n
=
128
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
80
-n
=
127
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
7
-n
=
599
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
19
-n
=
512
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
11
-n
=
510
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
91
-n
=
636
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
31
-n
=
1024
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
64
-n
=
1000
-stride
=
1004
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
8
-n
=
1501
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
1826
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
5
-n
=
2040
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
7
-n
=
2734
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
example/ck_tile/12_smoothquant/example_smoothquant.cpp
View file @
a4501f13
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
8
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment