Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
96a0d5f6
"src/include/threadwise_gemm.hpp" did not exist on "1cc683a3a3add570d4cde015fa7da7ac5fb87d4d"
Commit
96a0d5f6
authored
Jan 16, 2025
by
illsilin
Browse files
merge from public
parents
bfdc2430
54de3e55
Changes
345
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
885 additions
and
337 deletions
+885
-337
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+36
-11
example/ck_tile/03_gemm/script/benchmark_basic.sh
example/ck_tile/03_gemm/script/benchmark_basic.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
+13
-0
example/ck_tile/03_gemm/script/run_full_test.sh
example/ck_tile/03_gemm/script/run_full_test.sh
+22
-2
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
+1
-1
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
+35
-0
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+25
-47
example/ck_tile/05_reduce/reduce.cpp
example/ck_tile/05_reduce/reduce.cpp
+1
-1
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
+27
-4
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
+31
-9
example/ck_tile/10_rmsnorm2d/generate.py
example/ck_tile/10_rmsnorm2d/generate.py
+681
-0
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
+0
-146
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
+0
-22
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
+0
-13
No files found.
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
96a0d5f6
...
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
...
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int
n_warmup
,
int
n_warmup
,
int
n_repeat
)
int
n_repeat
)
{
{
gemm_basic_a
rgs
args
;
ck_tile
::
GemmHostA
rgs
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
kbatch
=
kbatch
;
args
.
k
_
batch
=
kbatch
;
args
.
M
=
M
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
N
=
N
;
args
.
K
=
K
;
args
.
K
=
K
;
...
@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
index_t
stride_B
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_B
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_C
=
arg_parser
.
get_int
(
"stride_c"
);
ck_tile
::
index_t
stride_C
=
arg_parser
.
get_int
(
"stride_c"
);
ck_tile
::
index_t
batch
_size
=
arg_parser
.
get_int
(
"
b
"
);
ck_tile
::
index_t
k
batch
=
arg_parser
.
get_int
(
"
split_k
"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
using
namespace
ck_tile
::
literals
;
using
namespace
ck_tile
::
literals
;
...
@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc,
stride_A
,
stride_A
,
stride_B
,
stride_B
,
stride_C
,
stride_C
,
batch
_size
,
k
batch
,
n_warmup
,
n_warmup
,
n_repeat
);
n_repeat
);
...
@@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_A
,
M
*
K
*
sizeof
(
ADataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_B
,
N
*
K
*
sizeof
(
BDataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_C
,
M
*
N
*
sizeof
(
CDataType
)));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_A
,
a_m_k_dev_buf
.
GetDeviceBuffer
(),
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_B
,
b_k_n_dev_buf
.
GetDeviceBuffer
(),
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
));
ck_tile
::
reference_gemm_gpu
<
ADataType
,
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
>
(
CLayout
>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
);
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_gpu_buf_ref
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
);
ck_tile
::
hip_check_error
(
hipMemcpy
(
c_m_n_gpu_buf_ref
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
));
ck_tile
::
hip_check_error
(
hipFree
(
d_A
));
ck_tile
::
hip_check_error
(
hipFree
(
d_B
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
...
...
example/ck_tile/03_gemm/script/benchmark_basic.sh
0 → 100755
View file @
96a0d5f6
#!/bin/sh
# Performance sweep for the tile_example_gemm_basic binary.
#
# Runs the example in fp16 for every combination of B-matrix layout (R, C)
# and the M / N / K sizes listed below. A and C layouts are fixed to "R".
# Must be started from a directory that contains the built binary somewhere
# underneath it, because the binary is located with `find .`.

# Take the first regular file named tile_example_gemm_basic below the CWD.
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"

# Fail once, clearly, instead of trying to execute "-prec=fp16" as a command
# on every iteration when the binary has not been built.
if [ -z "$EXE" ]; then
    echo "Error: tile_example_gemm_basic not found under $(pwd)" >&2
    exit 1
fi

# Forwarded to the example as -v (presumably 0 disables result validation
# for benchmarking -- confirm against the example's argument parser).
VALID=0

for b_matrix_layout in "R" "C"; do
    for m in "64" "512" "1024" "2048"; do
        for n in "512" "1024" "2048"; do
            for k in "64" "512" "1024" "2048"; do
                # Quoted so a build path containing spaces still works.
                "$EXE" -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
            done
        done
    done
done
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
0 → 100755
View file @
96a0d5f6
#!/bin/sh
# Performance sweep for the tile_example_gemm_universal binary.
#
# Runs the example in fp16 for every combination of B-matrix layout (R, C)
# and the M / N / K sizes listed below. A and C layouts are fixed to "R".
# Must be started from a directory that contains the built binary somewhere
# underneath it, because the binary is located with `find .`.

# Take the first regular file named tile_example_gemm_universal below the CWD.
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"

# Fail once, clearly, instead of trying to execute "-prec=fp16" as a command
# on every iteration when the binary has not been built.
if [ -z "$EXE" ]; then
    echo "Error: tile_example_gemm_universal not found under $(pwd)" >&2
    exit 1
fi

# Forwarded to the example as -v (presumably 0 disables result validation
# for benchmarking -- confirm against the example's argument parser).
VALID=0

for b_matrix_layout in "R" "C"; do
    for m in "64" "512" "1024" "2048"; do
        for n in "512" "1024" "2048"; do
            for k in "64" "512" "1024" "2048"; do
                # Quoted so a build path containing spaces still works.
                "$EXE" -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
            done
        done
    done
done
example/ck_tile/03_gemm/script/run_full_test.sh
View file @
96a0d5f6
...
@@ -19,7 +19,27 @@ echo 'Host name: ' $host_name
...
@@ -19,7 +19,27 @@ echo 'Host name: ' $host_name
export
GPU_arch
=
$4
export
GPU_arch
=
$4
echo
'GPU_arch: '
$GPU_arch
echo
'GPU_arch: '
$GPU_arch
# Start a fresh log file and write a header describing the run environment.
#   $1  path of the log file (removed first, then recreated)
#   $2  environment type string
#   $3  branch name
#   $4  node / host name
# All expansions are quoted so that a log path or argument containing spaces
# or glob characters is written literally instead of being word-split.
function print_log_header(){
	rm -f "$1";
	echo 'On branch ' "$3" &> "$1";
	echo 'Node name: ' "$4" >> "$1";
	# get GPU architecture and compute units from rocminfo
	echo -n "GPU_arch: " >> "$1"; rocminfo | grep "Name:" | grep "gfx" >> "$1";
	rocminfo | grep "Compute Unit:" >> "$1";
	hipcc --version | grep -e 'HIP version' >> "$1";
	echo 'Environment type: ' "$2" >> "$1";
	/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> "$1";
}
# run verification tests
# run verification tests
example/ck_tile/03_gemm/script/smoke_test.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
# run performance benchmarks
export
gemm_basic_log
=
"perf_tile_gemm_basic_fp16_
$GPU_arch
.log"
print_log_header
$gemm_basic_log
$env_type
$branch
$host_name
example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 |
tee
-a
$gemm_basic_log
# We do not have a performance benchmark for gemm yet. Will add it in the future.
export
gemm_mem_pipeline_log
=
"perf_tile_gemm_mem_pipeline_fp16_
$GPU_arch
.log"
\ No newline at end of file
print_log_header
$gemm_mem_pipeline_log
$env_type
$branch
$host_name
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 |
tee
-a
$gemm_mem_pipeline_log
example/ck_tile/03_gemm/script/smoke_test.sh
→
example/ck_tile/03_gemm/script/smoke_test
_basic
.sh
View file @
96a0d5f6
...
@@ -32,4 +32,4 @@ set -x
...
@@ -32,4 +32,4 @@ set -x
run_fp16_tests
run_fp16_tests
set
+x
set
+x
\ No newline at end of file
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
0 → 100755
View file @
96a0d5f6
#!/bin/bash
# Smoke test for the tile_example_gemm_universal binary.
#
# Runs the example in fp16 over a small grid of batch / M / N / K values and
# prints one Success/Error line per configuration. A failing configuration
# does not stop the sweep.

# Take the first regular file named tile_example_gemm_universal below the CWD.
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"

# Without the binary every iteration would only report a confusing
# "command not found" failure, so stop here with a single clear message.
if [ -z "$EXE" ]; then
    echo "Error: tile_example_gemm_universal not found under $(pwd)" >&2
    exit 1
fi

KNAME=1

export CK_WARMUP=0
export CK_REPEAT=1

# Arguments shared by every invocation; intentionally expanded unquoted below
# so that they split into three separate options.
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'

# Run every (batch, m, n, k) combination and report its exit status.
run_fp16_tests() {
    for batch in 1 2; do
        for m in 128 1024; do
            for n in 128 2048; do
                for k in 32 64; do

                    "$EXE" -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
                    if [ $? -eq 0 ]; then
                        echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
                    else
                        echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
                        # Optionally, exit or break if you need to halt further execution
                        # exit 1
                    fi

                done
            done
        done
    done
}

set -x

run_fp16_tests

set +x
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
96a0d5f6
...
@@ -9,20 +9,11 @@
...
@@ -9,20 +9,11 @@
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
gemm_basic_a
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
// Memory friendly for Interwave scheduler
...
@@ -71,14 +62,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -71,14 +62,15 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
#endif
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
using
BaseGemmPipeline
=
UNIVERSAL_GEMM_PIPELINE
<
GemmPipelineProblem
>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
K_split
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
...
@@ -87,36 +79,22 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -87,36 +79,22 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const
auto
Run
=
[
&
](
const
auto
has_hot_loop_
,
const
auto
tail_number_
)
{
const
auto
Run
=
[
&
](
const
auto
has_hot_loop_
,
const
auto
tail_number_
)
{
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
constexpr
auto
scheduler
=
GEMM_PIPELINE_SCHEDULER
;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
using
UniversalGemmProblem
=
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
BDataType
,
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrCompV3
<
AccDataType
,
#endif
GemmShape
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
Traits
,
BDataType
,
scheduler
,
AccDataType
,
has_hot_loop_v
,
GemmShape
,
tail_number_v
>
;
Traits
,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
>
;
ck_tile
::
GemmPipelineScheduler
::
Interwave
,
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
#endif
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
has_hot_loop_v
,
tail_number_v
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_c
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
...
...
example/ck_tile/05_reduce/reduce.cpp
View file @
96a0d5f6
...
@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// using WarpTile = ck_tile::sequence<1, 512>;
// using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>;
// using Vector = ck_tile::sequence<1, 8>;
constexpr
ck_tile
::
index_t
kBlockSize
=
512
;
constexpr
ck_tile
::
index_t
kBlockSize
=
256
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
ck_tile
::
index_t
kGridSize
=
(
m
/
BlockTile
::
at
(
ck_tile
::
number
<
0
>
{}));
ck_tile
::
index_t
kGridSize
=
(
m
/
BlockTile
::
at
(
ck_tile
::
number
<
0
>
{}));
std
::
cout
<<
"grid size "
<<
kGridSize
<<
std
::
endl
;
std
::
cout
<<
"grid size "
<<
kGridSize
<<
std
::
endl
;
...
...
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
View file @
96a0d5f6
set
(
RMSNORM2D_FWD_KNOWN_APIS
"fwd;bwd"
)
set
(
RMSNORM2D_FWD_ENABLE_APIS
"fwd"
CACHE STRING
"semicolon-separated list of APIs to generate (
${
RMSNORM2D_FWD_KNOWN_APIS
}
) & link, or
\"
all
\"
."
)
if
(
RMSNORM2D_FWD_ENABLE_APIS STREQUAL
"all"
)
set
(
RMSNORM2D_FWD_ENABLE_APIS
${
RMSNORM2D_FWD_KNOWN_APIS
}
)
endif
()
# generate a list of kernels, but do not actually emit files at the configure stage
execute_process
(
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
${
RMSNORM2D_FWD_ENABLE_APIS
}
--working_path
${
CMAKE_CURRENT_BINARY_DIR
}
--list_blobs
RESULT_VARIABLE ret
)
if
(
ret AND NOT ret EQUAL 0
)
message
(
FATAL_ERROR
"Fail to generate kernels via Python.
${
ret
}
"
)
endif
()
file
(
STRINGS
${
CMAKE_CURRENT_BINARY_DIR
}
/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS
)
add_custom_command
(
OUTPUT
${
RMSNORM2D_FWD_GEN_BLOBS
}
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
${
RMSNORM2D_FWD_ENABLE_APIS
}
--working_path
${
CMAKE_CURRENT_BINARY_DIR
}
--gen_blobs
)
set
(
TILE_RMSNORM2D_FWD
"tile_rmsnorm2d_fwd"
)
set
(
TILE_RMSNORM2D_FWD
"tile_rmsnorm2d_fwd"
)
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message
(
"adding
${
TILE_RMSNORM2D_FWD
}
"
)
message
(
"adding
${
TILE_RMSNORM2D_FWD
}
"
)
file
(
GLOB INSTANCE_SRCS instances/*.cpp
)
add_executable
(
${
TILE_RMSNORM2D_FWD
}
EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp
)
add_executable
(
${
TILE_RMSNORM2D_FWD
}
EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp
)
target_include_directories
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_include_directories
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
INSTANCE_SRC
S
}
)
target_sources
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
RMSNORM2D_FWD_GEN_BLOB
S
}
)
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
...
...
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
View file @
96a0d5f6
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <cstring>
#include <cstring>
...
@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert
(
stride
>=
n
);
assert
(
stride
>=
n
);
using
XDataType
=
DataType
;
using
XDataType
=
DataType
;
using
YDataType
=
DataType
;
using
YDataType
=
DataType
;
using
GammaDataType
=
DataType
;
using
GammaDataType
=
DataType
;
using
InvRmsDataType
=
ck_tile
::
null_type
;
using
InvRmsDataType
=
ck_tile
::
null_type
;
using
SmoothScaleDataType
=
ck_tile
::
null_type
;
using
YScaleDataType
=
ck_tile
::
null_type
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
...
@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
BlockTile
=
ck_tile
::
sequence
<
2
,
128
>
;
using
BlockTile
=
ck_tile
::
sequence
<
2
,
128
>
;
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
PipelineTraits
=
ck_tile
::
Rmsnorm2dFwdTraits
<
true
,
// kPadN
false
,
// kSaveInvRms
kTwoPass
,
ck_tile
::
Rmsnorm2dFusedAddEnum
::
NO_ADD
,
// fuse add
ck_tile
::
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
>
;
// fuse quant
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Problem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
XDataType
,
using
Problem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
XDataType
,
GammaDataType
,
GammaDataType
,
ComputeDataType
,
ComputeDataType
,
YDataType
,
YDataType
,
InvRmsDataType
,
InvRmsDataType
,
SmoothScaleDataType
,
YScaleDataType
,
Shape
,
Shape
,
true
,
// kPadN
PipelineTraits
>
;
false
,
// kSaveInvRms
kTwoPass
>
;
using
OnePassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineOnePass
<
Problem
>
;
using
OnePassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineOnePass
<
Problem
>
;
using
TwoPassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineTwoPass
<
Problem
>
;
using
TwoPassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineTwoPass
<
Problem
>
;
using
Pipeline
=
std
::
conditional_t
<
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Pipeline
=
std
::
conditional_t
<
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
>
;
using
Default2DEpilogueProblem
=
ck_tile
::
Default2DEpilogueProblem
<
ComputeDataType
,
YDataType
,
false
,
PipelineTraits
::
kPadN
,
false
>
;
using
Default2DEpilogue
=
ck_tile
::
Default2DEpilogue
<
Default2DEpilogueProblem
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
,
Default2DEpilogue
>
;
ck_tile
::
Rmsnorm2dFwdHostArgs
args
{
x_buf
.
GetDeviceBuffer
(),
ck_tile
::
Rmsnorm2dFwdHostArgs
args
{
x_buf
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
gamma_buf
.
GetDeviceBuffer
(),
gamma_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
nullptr
,
nullptr
,
epsilon
,
epsilon
,
m
,
m
,
n
,
n
,
stride
,
stride
,
stride
,
stride
};
stride
};
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
...
...
example/ck_tile/10_rmsnorm2d/generate.py
0 → 100644
View file @
96a0d5f6
This diff is collapsed.
Click to expand it.
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
/// Local shorthand for rmsnorm2d_fwd_traits_: every template argument is
/// forwarded unchanged and in the same order.
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t Repeat_N_,         // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kSaveInvRms_,
          bool kTwoPass_>
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
                                     Repeat_M_,
                                     Repeat_N_,
                                     ThreadPerBlock_M_,
                                     ThreadPerBlock_N_,
                                     Vector_N_,
                                     kPadN_,
                                     kSaveInvRms_,
                                     kTwoPass_>;
/// Selects and launches one rmsnorm2d_fwd_ instantiation for a 16-bit data type.
///
/// The row length a.n first picks a size bucket (<=64, <=128, ... <=4096, >4096).
/// Inside a bucket the largest of 8 / 4 / 2 that divides a.n picks the variant
/// with that vector size along N, falling back to vector size 1. Only the
/// ">4096" bucket uses the two-pass instantiations (last trait flag true).
///
/// @param a  kernel arguments; only a.n is inspected here, the whole struct is
///           forwarded to the chosen instance
/// @param s  stream configuration forwarded to the chosen instance
/// @return   whatever the chosen rmsnorm2d_fwd_ instance returns, or -1 if no
///           bucket matched
template <typename data_type>
float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
                         rmsnorm2d_fwd_args a,
                         const ck_tile::stream_config& s)
{
    // clang-format off
    //                                                        rm  rn  tm    tn  vn    pd    rms     2p
    if(a.n <= 64)
    {
        return rmsnorm2d_fwd_<trait_<data_type,  1,  1,  4,   64,  1, true, false, false>>(s, a);
    }
    if(a.n <= 128)
    {
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  1,  4,   64,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  4,   64,  1, true, false, false>>(s, a);
    }
    if(a.n <= 256)
    {
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  1,  4,   64,  4, true, false, false>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  4,   64,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  4,   64,  1, true, false, false>>(s, a);
    }
    if(a.n <= 512)
    {
        if(a.n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  1,  4,   64,  8, true, false, false>>(s, a);
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  4,   64,  4, true, false, false>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  4,   64,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  8,  4,   64,  1, true, false, false>>(s, a);
    }
    if(a.n <= 768)
    {
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  3,  4,   64,  4, true, false, false>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  6,  4,   64,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1, 12,  4,   64,  1, true, false, false>>(s, a);
    }
    if(a.n <= 1024)
    {
        if(a.n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  1,  2,  128,  8, true, false, false>>(s, a);
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  2,  128,  4, true, false, false>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  2,  128,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  1,  256,  1, true, false, false>>(s, a);
    }
    if(a.n <= 1536)
    {
        if(a.n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  3,  4,   64,  8, true, false, false>>(s, a);
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  3,  2,  128,  4, true, false, false>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  3,  1,  256,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  6,  1,  256,  1, true, false, false>>(s, a);
    }
    if(a.n <= 2048)
    {
        if(a.n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  1,  1,  256,  8, true, false, false>>(s, a);
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  1,  256,  4, true, false, false>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  1,  256,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  8,  1,  256,  1, true, false, false>>(s, a);
    }
    if(a.n <= 3072)
    {
        if(a.n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  3,  1,  128,  8, true, false, false>>(s, a);
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  3,  1,  256,  4, true, false, false>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  6,  1,  256,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  3,  1, 1024,  1, true, false, false>>(s, a);
    }
    if(a.n <= 4096)
    {
        if(a.n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  1,  256,  8, true, false, false>>(s, a);
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  1,  256,  4, true, false, false>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  1, 1024,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  1, 1024,  1, true, false, false>>(s, a);
    }
    if(a.n > 4096)
    {
        // Same tile shapes as the <=4096 bucket, but with the two-pass flag set.
        if(a.n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  1,  256,  8, true, false, true>>(s, a);
        if(a.n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  1,  256,  4, true, false, true>>(s, a);
        if(a.n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,  1,  2,  1, 1024,  2, true, false, true>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,  1,  4,  1, 1024,  1, true, false, true>>(s, a);
    }
    // No bucket matched: report the same sentinel the result started out as.
    return -1;
    // clang-format on
}
/// Public entry point: dispatches on the data-type name carried in the traits.
///
/// @param t  runtime traits; t.data_type must compare equal to "fp16" or "bf16"
/// @param a  kernel arguments, forwarded unchanged
/// @param s  stream configuration, forwarded unchanged
/// @return   the value returned by the selected rmsnorm2d_fwd_b16_ instantiation
/// @throws std::runtime_error when t.data_type is neither "fp16" nor "bf16"
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
                    rmsnorm2d_fwd_args a,
                    const ck_tile::stream_config& s)
{
    const bool is_fp16 = (t.data_type.compare("fp16") == 0);
    if(is_fp16)
    {
        return rmsnorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
    }

    const bool is_bf16 = (t.data_type.compare("bf16") == 0);
    if(is_bf16)
    {
        return rmsnorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
    }

    throw std::runtime_error("Without supported instances!");
}
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4,  64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1,  128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1,  256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1,  256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1,  256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1,  256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1,  256, 8, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1,  256, 4, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
deleted
100644 → 0
View file @
bfdc2430
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false, false>>(const S&, A);
// clang-format on
Prev
1
2
3
4
5
6
7
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment