Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
019d4b7c
"src/regex.cpp" did not exist on "fb9176a0543e5bb704356e5469cb0a9ac8c9e03a"
Commit
019d4b7c
authored
Jan 17, 2025
by
illsilin
Browse files
merge from public repo
parents
07307ea1
5063a39f
Changes
344
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
385 additions
and
435 deletions
+385
-435
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
+0
-22
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
.../10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
+0
-65
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
+280
-54
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
+34
-85
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
+29
-25
example/ck_tile/12_smoothquant/example_smoothquant.cpp
example/ck_tile/12_smoothquant/example_smoothquant.cpp
+15
-15
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
.../12_smoothquant/instances/smoothquant_instance_common.hpp
+2
-2
example/ck_tile/12_smoothquant/smoothquant.cpp
example/ck_tile/12_smoothquant/smoothquant.cpp
+14
-14
example/ck_tile/12_smoothquant/smoothquant.hpp
example/ck_tile/12_smoothquant/smoothquant.hpp
+11
-11
No files found.
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
2
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
2
,
128
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
2
,
128
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
2
,
128
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp
deleted
100644 → 0
View file @
07307ea1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <iostream>
#pragma once
using
S
=
ck_tile
::
stream_config
;
using
A
=
rmsnorm2d_fwd_args
;
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
>
using
trait_
=
rmsnorm2d_fwd_traits_
<
DataType_
,
Repeat_M_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
,
kSaveInvRms_
,
kTwoPass_
>
;
template
<
typename
Traits_
>
float
rmsnorm2d_fwd_
(
const
S
&
s
,
A
a
)
{
using
DataType
=
typename
Traits_
::
DataType
;
using
PipelineProblem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
typename
RmsnormTypeConfig
<
DataType
>::
XDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
GammaDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
ComputeDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
YDataType
,
typename
RmsnormTypeConfig
<
DataType
>::
InvRmsDataType
,
typename
Traits_
::
Shape
,
Traits_
::
kPadN
,
Traits_
::
kSaveInvRms
,
Traits_
::
kTwoPass
>
;
using
OnePassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineOnePass
<
PipelineProblem
>
;
using
TwoPassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineTwoPass
<
PipelineProblem
>
;
using
Pipeline
=
std
::
conditional_t
<
Traits_
::
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
>
;
const
dim3
grids
=
Kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
auto
kargs
=
Kernel
::
MakeKargs
(
a
);
if
(
s
.
log_level_
>
0
)
std
::
cout
<<
", "
<<
Kernel
::
GetName
()
<<
std
::
flush
;
return
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
}
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
View file @
019d4b7c
This diff is collapsed.
Click to expand it.
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
View file @
019d4b7c
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -8,27 +8,34 @@
...
@@ -8,27 +8,34 @@
#include "ck_tile/ops/rmsnorm2d.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <string>
#include <string>
template
<
typename
DataType
>
template
<
typename
InType
,
typename
OutType
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
>
struct
RmsnormTypeConfig
;
struct
RmsnormTypeConfig
;
template
<
>
template
<
typename
OutType
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
>
struct
RmsnormTypeConfig
<
ck_tile
::
half_t
>
struct
RmsnormTypeConfig
<
ck_tile
::
half_t
,
OutType
,
SmoothScaleDataType_
,
YScaleDataType_
>
{
{
using
XDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
YDataType
=
OutType
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
InvRmsDataType
=
ck_tile
::
half_t
;
using
InvRmsDataType
=
ck_tile
::
half_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
using
SmoothScaleDataType
=
SmoothScaleDataType_
;
using
YScaleDataType
=
YScaleDataType_
;
};
};
template
<
>
template
<
typename
OutType
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
>
struct
RmsnormTypeConfig
<
ck_tile
::
bf16_t
>
struct
RmsnormTypeConfig
<
ck_tile
::
bf16_t
,
OutType
,
SmoothScaleDataType_
,
YScaleDataType_
>
{
{
using
XDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
YDataType
=
ck_tile
::
bf16_t
;
using
YDataType
=
OutType
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
InvRmsDataType
=
ck_tile
::
bf16_t
;
using
InvRmsDataType
=
ck_tile
::
bf16_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
using
SmoothScaleDataType
=
SmoothScaleDataType_
;
using
YScaleDataType
=
YScaleDataType_
;
};
};
// runtime args
// runtime args
...
@@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
...
@@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
{
{
};
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
>
struct
rmsnorm2d_fwd_traits_
{
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
bool
is_warp_per_row
=
ThreadPerBlock_N_
<=
warpSize
;
static_assert
((
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
%
warpSize
==
0
);
static
constexpr
ck_tile
::
index_t
total_warps
=
(
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
/
warpSize
;
// num of warps along m
static
constexpr
ck_tile
::
index_t
BlockWarps_M
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
total_warps
*
(
warpSize
/
ThreadPerBlock_N_
);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return
total_warps
/
(
ThreadPerBlock_N_
/
warpSize
);
}
}();
// num of warps along n
static
constexpr
ck_tile
::
index_t
BlockWarps_N
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
1
;
}
else
{
static_assert
(
ThreadPerBlock_N_
%
warpSize
==
0
);
return
ThreadPerBlock_N_
/
warpSize
;
}
}();
static
constexpr
ck_tile
::
index_t
Repeat_M
=
Repeat_M_
;
static
constexpr
ck_tile
::
index_t
Repeat_N
=
Repeat_N_
;
static
constexpr
ck_tile
::
index_t
Block_M
=
Repeat_M_
*
ThreadPerBlock_M_
;
static
constexpr
ck_tile
::
index_t
Block_N
=
Repeat_N_
*
ThreadPerBlock_N_
*
Vector_N_
;
static
constexpr
ck_tile
::
index_t
Warp_M
=
ThreadPerBlock_M_
/
BlockWarps_M
;
static
constexpr
ck_tile
::
index_t
Warp_N
=
ThreadPerBlock_N_
/
BlockWarps_N
*
Vector_N_
;
using
BlockTile
=
ck_tile
::
sequence
<
Block_M
,
Block_N
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
BlockWarps_M
,
BlockWarps_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
template
<
typename
Traits_
>
template
<
typename
Traits_
>
float
rmsnorm2d_fwd_
(
const
ck_tile
::
stream_config
&
s
,
rmsnorm2d_fwd_args
a
);
float
rmsnorm2d_fwd_
(
const
ck_tile
::
stream_config
&
s
,
rmsnorm2d_fwd_args
a
);
// This is the public API, will be generated by script
// This is the public API, will be generated by script
struct
rmsnorm2d_fwd_traits
struct
rmsnorm2d_fwd_traits
{
{
std
::
string
data_type
;
std
::
string
prec_i
;
// input precision
std
::
string
prec_o
;
// output precision
// if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
// can set arbitrary(will skip check)
std
::
string
prec_sm
;
// x-scale, used for [1*N] input smooth quant
std
::
string
prec_sy
;
// y-scale, used for [M*1] output for next layer
bool
save_rms
;
bool
save_rms
;
int
fused_add
;
// 0:no-add, 1:pre-add-store, 2:pre-add
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
};
float
rmsnorm2d_fwd
(
rmsnorm2d_fwd_traits
,
rmsnorm2d_fwd_args
,
const
ck_tile
::
stream_config
&
);
float
rmsnorm2d_fwd
(
rmsnorm2d_fwd_traits
,
rmsnorm2d_fwd_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
View file @
019d4b7c
#!/bin/sh
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
$EXE
-prec
=
$pr_i
-m
=
99
-n
=
13
for
fadd
in
"0"
"1"
;
do
$EXE
-prec
=
$pr_i
-m
=
17
-n
=
16
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
100
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
17
-n
=
16
$EXE
-prec
=
$pr_i
-m
=
4
-n
=
128
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
100
$EXE
-prec
=
$pr_i
-m
=
80
-n
=
127
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
4
-n
=
128
$EXE
-prec
=
$pr_i
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
80
-n
=
127
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
599
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec
=
$pr_i
-m
=
19
-n
=
512
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
7
-n
=
599
$EXE
-prec
=
$pr_i
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
19
-n
=
512
$EXE
-prec
=
$pr_i
-m
=
11
-n
=
510
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec
=
$pr_i
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
11
-n
=
510
$EXE
-prec
=
$pr_i
-m
=
91
-n
=
636
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec
=
$pr_i
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
91
-n
=
636
$EXE
-prec
=
$pr_i
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec
=
$pr_i
-m
=
31
-n
=
1024
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec
=
$pr_i
-m
=
64
-n
=
1000
-stride
=
1004
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
31
-n
=
1024
$EXE
-prec
=
$pr_i
-m
=
8
-n
=
1501
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
64
-n
=
1000
-stride
=
1004
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
1826
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
8
-n
=
1501
$EXE
-prec
=
$pr_i
-m
=
5
-n
=
2040
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
1826
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
2734
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
5
-n
=
2040
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
7
-n
=
2734
$EXE
-prec
=
$pr_i
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
8192
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
10547
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
17134
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
done
example/ck_tile/12_smoothquant/example_smoothquant.cpp
View file @
019d4b7c
This diff is collapsed.
Click to expand it.
example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp
View file @
019d4b7c
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include <ck_tile/core.hpp>
#include "smoothquant.hpp"
#include "smoothquant.hpp"
...
@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
...
@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
using
PipelineProblem
=
ck_tile
::
SmoothquantPipelineProblem
<
using
PipelineProblem
=
ck_tile
::
SmoothquantPipelineProblem
<
typename
SmoothquantTypeConfig
<
DataType
>::
XDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
XDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
X
ScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
Smooth
ScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
ComputeDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
ComputeDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
YScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
YScaleDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
QYDataType
,
typename
SmoothquantTypeConfig
<
DataType
>::
QYDataType
,
...
...
example/ck_tile/12_smoothquant/smoothquant.cpp
View file @
019d4b7c
This diff is collapsed.
Click to expand it.
example/ck_tile/12_smoothquant/smoothquant.hpp
View file @
019d4b7c
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig;
...
@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig;
template
<
>
template
<
>
struct
SmoothquantTypeConfig
<
ck_tile
::
half_t
>
struct
SmoothquantTypeConfig
<
ck_tile
::
half_t
>
{
{
using
XDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
X
ScaleDataType
=
float
;
using
Smooth
ScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
};
};
template
<
>
template
<
>
struct
SmoothquantTypeConfig
<
ck_tile
::
bf16_t
>
struct
SmoothquantTypeConfig
<
ck_tile
::
bf16_t
>
{
{
using
XDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
X
ScaleDataType
=
float
;
using
Smooth
ScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
};
};
// runtime args
// runtime args
...
...
Prev
1
2
3
4
5
6
7
8
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment