Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1b616990
Commit
1b616990
authored
Feb 05, 2025
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into update_cka8w8_uc
parents
af30d6b6
800cf897
Changes
553
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1065 additions
and
396 deletions
+1065
-396
example/ck_tile/02_layernorm2d/script/smoke_test.sh
example/ck_tile/02_layernorm2d/script/smoke_test.sh
+3
-2
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+1
-1
example/ck_tile/03_gemm/README.md
example/ck_tile/03_gemm/README.md
+5
-2
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+38
-31
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+25
-4
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+68
-90
example/ck_tile/03_gemm/script/benchmark_basic.sh
example/ck_tile/03_gemm/script/benchmark_basic.sh
+13
-0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
+13
-0
example/ck_tile/03_gemm/script/run_full_test.sh
example/ck_tile/03_gemm/script/run_full_test.sh
+22
-2
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
+1
-1
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
+35
-0
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+98
-53
example/ck_tile/05_reduce/reduce.cpp
example/ck_tile/05_reduce/reduce.cpp
+1
-1
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
+28
-5
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
+31
-9
example/ck_tile/10_rmsnorm2d/generate.py
example/ck_tile/10_rmsnorm2d/generate.py
+683
-0
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
+0
-146
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
+0
-22
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
+0
-14
No files found.
Too many changes to show.
To preserve performance only
553 of 553+
files are displayed.
Plain diff
Email patch
example/ck_tile/02_layernorm2d/script/smoke_test.sh
View file @
1b616990
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_layernorm2d_fwd
-type
f |
head
-n
1
)
"
for
fquant
in
""
"-fquant=1 -prec_o=int8"
;
do
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
fadd
in
"0"
"1"
;
do
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
...
...
@@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
9120
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
...
...
example/ck_tile/03_gemm/CMakeLists.txt
View file @
1b616990
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_universal
_gemm
EXCLUDE_FROM_ALL universal_gemm.cpp
)
add_executable
(
tile_example_
gemm_
universal EXCLUDE_FROM_ALL universal_gemm.cpp
)
example/ck_tile/03_gemm/README.md
View file @
1b616990
...
...
@@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation
make tile_example_gemm_
mem_pipeline
-j
make tile_example_gemm_
universal
-j
```
This will result in an executable
`build/bin/tile_example_gemm_basic`
This will result in an executable
`build/bin/tile_example_gemm_basic`
&
`build/bin/tile_example_gemm_universal`
## example
```
...
...
@@ -22,6 +22,9 @@ args:
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: R)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
1b616990
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
...
...
@@ -9,8 +9,6 @@
#include <string>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
...
...
@@ -22,10 +20,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
// This part comes from the Codegen
...
...
@@ -41,40 +35,31 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
CodegenPipelineProblem
::
TransposeC
>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
@@ -105,4 +90,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
1b616990
...
...
@@ -8,6 +8,27 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
template
<
typename
DataType
>
struct
GemmBasicTypeConfig
;
...
...
@@ -54,12 +75,11 @@ using CDataType = Types::CDataType;
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"b"
,
"1"
,
"batch size"
)
.
insert
(
"m"
,
"3840"
,
"m dimension"
)
arg_parser
.
insert
(
"m"
,
"3840"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"k"
,
"2048"
,
"k dimension"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"
R
"
,
"B tensor data layout -
Row
by default"
)
.
insert
(
"b_layout"
,
"
C
"
,
"B tensor data layout -
Column
by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"stride_a"
,
"0"
,
"Tensor A stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
...
...
@@ -68,7 +88,8 @@ auto create_args(int argc, char* argv[])
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
);
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"split_k"
,
"1"
,
"splitK value"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
1b616990
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template
<
typename
Layout
>
static
constexpr
inline
auto
is_row_major
(
Layout
layout_
)
{
return
ck_tile
::
bool_constant
<
std
::
is_same_v
<
ck_tile
::
remove_cvref_t
<
decltype
(
layout_
)
>
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>>
{};
}
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
const
ck_tile
::
index_t
kbatch
,
const
float
max_accumulated_value
)
{
using
ComputeType
=
std
::
conditional_t
<
sizeof
(
ADataType
)
<
sizeof
(
BDataType
),
ADataType
,
BDataType
>
;
// Calculate thresholds
const
auto
rtol
=
ck_tile
::
get_relative_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
const
auto
atol
=
ck_tile
::
get_absolute_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
max_accumulated_value
/
kbatch
,
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
// Calculate error due to split_k accumulation
const
auto
rtol_split_k
=
ck_tile
::
get_relative_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
kbatch
);
const
auto
atol_split_k
=
ck_tile
::
get_absolute_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
max_accumulated_value
,
kbatch
);
// Use higher threshold
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
...
...
@@ -64,52 +91,20 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
index_t
stride_B
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_C
=
arg_parser
.
get_int
(
"stride_c"
);
ck_tile
::
index_t
batch_size
=
arg_parser
.
get_int
(
"b"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
using
namespace
ck_tile
::
literals
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1_
uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
1_
uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
stride_A
=
f_get_default_stride
(
M
,
K
,
stride_A
,
a_layout
);
stride_B
=
f_get_default_stride
(
K
,
N
,
stride_B
,
b_layout
);
stride_C
=
f_get_default_stride
(
M
,
N
,
stride_C
,
CLayout
{});
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
stride_A
,
a_layout
));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
stride_B
,
b_layout
));
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
stride_A
=
ck_tile
::
get_default_stride
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
));
stride_B
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
));
stride_C
=
ck_tile
::
get_default_stride
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}));
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
ck_tile
::
host_tensor_descriptor
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
)));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
ck_tile
::
host_tensor_descriptor
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
)));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}))
)
;
// TODO: add different init types
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
...
...
@@ -133,7 +128,7 @@ int run_gemm_example_with_layouts(int argc,
stride_A
,
stride_B
,
stride_C
,
batch
_size
,
k
batch
,
n_warmup
,
n_repeat
);
...
...
@@ -143,20 +138,29 @@ int run_gemm_example_with_layouts(int argc,
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}))
)
;
c_m_n_host_ref
.
SetZero
();
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}))
)
;
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
...
...
@@ -196,46 +200,20 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
return
pass
;
}
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work.
// else if(a_layout == "C" && b_layout == "C")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
example/ck_tile/03_gemm/script/benchmark_basic.sh
0 → 100755
View file @
1b616990
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
0 → 100755
View file @
1b616990
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
example/ck_tile/03_gemm/script/run_full_test.sh
View file @
1b616990
...
...
@@ -19,7 +19,27 @@ echo 'Host name: ' $host_name
export
GPU_arch
=
$4
echo
'GPU_arch: '
$GPU_arch
function
print_log_header
(){
rm
-f
$1
;
echo
'On branch '
$3
&>
$1
;
echo
'Node name: '
$4
>>
$1
;
# get GPU architecture and compute units from rocminfo
echo
-n
"GPU_arch: "
>>
$1
;
rocminfo |
grep
"Name:"
|
grep
"gfx"
>>
$1
;
rocminfo |
grep
"Compute Unit:"
>>
$1
;
hipcc
--version
|
grep
-e
'HIP version'
>>
$1
;
echo
'Environment type: '
$2
>>
$1
;
/opt/rocm/bin/amdclang++
--version
|
grep
-e
'InstalledDir'
>>
$1
;
}
# run verification tests
example/ck_tile/03_gemm/script/smoke_test.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
# run performance benchmarks
export
gemm_basic_log
=
"perf_tile_gemm_basic_fp16_
$GPU_arch
.log"
print_log_header
$gemm_basic_log
$env_type
$branch
$host_name
example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 |
tee
-a
$gemm_basic_log
# We do not have a performance benchmark for gemm yet. Will add it in the future.
\ No newline at end of file
export
gemm_mem_pipeline_log
=
"perf_tile_gemm_mem_pipeline_fp16_
$GPU_arch
.log"
print_log_header
$gemm_mem_pipeline_log
$env_type
$branch
$host_name
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 |
tee
-a
$gemm_mem_pipeline_log
example/ck_tile/03_gemm/script/smoke_test.sh
→
example/ck_tile/03_gemm/script/smoke_test
_basic
.sh
View file @
1b616990
...
...
@@ -32,4 +32,4 @@ set -x
run_fp16_tests
set
+x
\ No newline at end of file
set
+x
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
0 → 100755
View file @
1b616990
#!/bin/bash
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
KNAME
=
1
export
CK_WARMUP
=
0
export
CK_REPEAT
=
1
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
run_fp16_tests
()
{
for
batch
in
1 2
;
do
for
m
in
128 1024
;
do
for
n
in
128 2048
;
do
for
k
in
32 64
;
do
$EXE
-b
=
$batch
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-e
=
1e-5
-prec
=
fp16
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
}
set
-x
run_fp16_tests
set
+x
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
1b616990
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
...
...
@@ -9,20 +9,11 @@
#include <string>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
gemm_basic_a
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
...
...
@@ -37,8 +28,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#
el
if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
...
...
@@ -57,7 +48,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
TileParitionerGroupNum
=
8
;
constexpr
ck_tile
::
index_t
TileParitionerM01
=
4
;
// ===============================================
...
...
@@ -65,20 +60,20 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
TilePartitioner
=
ck_tile
::
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
#endif
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
using
BaseGemmPipeline
=
UNIVERSAL_GEMM_PIPELINE
<
GemmPipelineProblem
>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
K_split
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
...
...
@@ -87,36 +82,36 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const
auto
Run
=
[
&
](
const
auto
has_hot_loop_
,
const
auto
tail_number_
)
{
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
constexpr
auto
scheduler
=
GEMM_PIPELINE_SCHEDULER
;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrCompV3
<
#endif
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
ck_tile
::
GemmPipelineScheduler
::
Interwave
,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
#endif
has_hot_loop_v
,
tail_number_v
>>
;
using
UniversalGemmProblem
=
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
GemmUniversalTraits
,
scheduler
,
has_hot_loop_v
,
tail_number_v
>
;
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
GemmPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
UniversalGemmProblem
::
TransposeC
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_c
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
...
...
@@ -139,6 +134,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
if
(
has_hot_loop
)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
}
else
{
std
::
ostringstream
err
;
err
<<
"For compute pipeline tail number should always be Full, but have
\"
"
<<
tail_num
<<
"
\"
which is not supported! PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
{
...
...
@@ -199,6 +209,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
}
}
#endif
}
else
{
...
...
@@ -223,4 +234,38 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
#include "run_gemm_example.inc"
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/ck_tile/05_reduce/reduce.cpp
View file @
1b616990
...
...
@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>;
constexpr
ck_tile
::
index_t
kBlockSize
=
512
;
constexpr
ck_tile
::
index_t
kBlockSize
=
256
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
ck_tile
::
index_t
kGridSize
=
(
m
/
BlockTile
::
at
(
ck_tile
::
number
<
0
>
{}));
std
::
cout
<<
"grid size "
<<
kGridSize
<<
std
::
endl
;
...
...
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
View file @
1b616990
set
(
RMSNORM2D_FWD_KNOWN_APIS
"fwd;bwd"
)
set
(
RMSNORM2D_FWD_ENABLE_APIS
"fwd"
CACHE STRING
"semicolon-separated list of APIs to generate (
${
RMSNORM2D_FWD_KNOWN_APIS
}
) & link, or
\"
all
\"
."
)
if
(
RMSNORM2D_FWD_ENABLE_APIS STREQUAL
"all"
)
set
(
RMSNORM2D_FWD_ENABLE_APIS
${
RMSNORM2D_FWD_KNOWN_APIS
}
)
endif
()
# generate a list of kernels, but not actually emit files at config sta
execute_process
(
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
${
RMSNORM2D_FWD_ENABLE_APIS
}
--working_path
${
CMAKE_CURRENT_BINARY_DIR
}
--list_blobs
RESULT_VARIABLE ret
)
if
(
ret AND NOT ret EQUAL 0
)
message
(
FATAL_ERROR
"Fail to generate kernels via Python.
${
ret
}
"
)
endif
()
file
(
STRINGS
${
CMAKE_CURRENT_BINARY_DIR
}
/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS
)
add_custom_command
(
OUTPUT
${
RMSNORM2D_FWD_GEN_BLOBS
}
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
${
RMSNORM2D_FWD_ENABLE_APIS
}
--working_path
${
CMAKE_CURRENT_BINARY_DIR
}
--gen_blobs
)
set
(
TILE_RMSNORM2D_FWD
"tile_rmsnorm2d_fwd"
)
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message
(
"adding
${
TILE_RMSNORM2D_FWD
}
"
)
file
(
GLOB INSTANCE_SRCS instances/*.cpp
)
add_executable
(
${
TILE_RMSNORM2D_FWD
}
EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp
)
target_include_directories
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
INSTANCE_SRC
S
}
)
target_sources
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
RMSNORM2D_FWD_GEN_BLOB
S
}
)
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
list
(
APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
--offload-compress
)
target_compile_options
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
}
)
...
...
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
View file @
1b616990
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <cstring>
...
...
@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert
(
stride
>=
n
);
using
XDataType
=
DataType
;
using
YDataType
=
DataType
;
using
GammaDataType
=
DataType
;
using
InvRmsDataType
=
ck_tile
::
null_type
;
using
XDataType
=
DataType
;
using
YDataType
=
DataType
;
using
GammaDataType
=
DataType
;
using
InvRmsDataType
=
ck_tile
::
null_type
;
using
SmoothScaleDataType
=
ck_tile
::
null_type
;
using
YScaleDataType
=
ck_tile
::
null_type
;
using
ComputeDataType
=
float
;
...
...
@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
BlockTile
=
ck_tile
::
sequence
<
2
,
128
>
;
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
PipelineTraits
=
ck_tile
::
Rmsnorm2dFwdTraits
<
true
,
// kPadN
false
,
// kSaveInvRms
kTwoPass
,
ck_tile
::
Rmsnorm2dFusedAddEnum
::
NO_ADD
,
// fuse add
ck_tile
::
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
>
;
// fuse quant
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Problem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
XDataType
,
GammaDataType
,
ComputeDataType
,
YDataType
,
InvRmsDataType
,
SmoothScaleDataType
,
YScaleDataType
,
Shape
,
true
,
// kPadN
false
,
// kSaveInvRms
kTwoPass
>
;
PipelineTraits
>
;
using
OnePassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineOnePass
<
Problem
>
;
using
TwoPassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineTwoPass
<
Problem
>
;
using
Pipeline
=
std
::
conditional_t
<
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
>
;
using
Default2DEpilogueProblem
=
ck_tile
::
Default2DEpilogueProblem
<
ComputeDataType
,
YDataType
,
false
,
PipelineTraits
::
kPadN
,
false
>
;
using
Default2DEpilogue
=
ck_tile
::
Default2DEpilogue
<
Default2DEpilogueProblem
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
,
Default2DEpilogue
>
;
ck_tile
::
Rmsnorm2dFwdHostArgs
args
{
x_buf
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
gamma_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
nullptr
,
epsilon
,
m
,
n
,
stride
,
stride
,
stride
,
stride
};
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
...
...
example/ck_tile/10_rmsnorm2d/generate.py
0 → 100644
View file @
1b616990
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import
argparse
from
enum
import
IntEnum
from
pathlib
import
Path
import
sys
from
typing
import
List
,
Optional
,
Any
import
functools
import
itertools
import
copy
from
dataclasses
import
dataclass
def
get_if_str
(
idx
,
total
,
lase_else
=
True
):
if
idx
==
0
:
return
'if'
elif
idx
<
total
-
1
:
return
'else if'
else
:
if
lase_else
:
return
'else'
else
:
return
'else if'
FUSED_ADD_ENUM_STR_MAP
=
[
'no'
,
'pras'
,
# pre-norm
'pra'
]
# post-norm
FUSED_FUSED_SWEEP_STR_MAP
=
[
'no'
,
'sdquant'
,
# smooth dynamic quant
'dquant'
]
# dynamic quant (without sm_scale)
DATA_TYPE_MAP
=
{
'fp32'
:
'float'
,
'fp16'
:
'ck_tile::fp16_t'
,
'bf16'
:
'ck_tile::bf16_t'
,
'int8'
:
'ck_tile::int8_t'
,
'fp8'
:
'ck_tile::fp8_t'
}
def
BOOL_MAP
(
b_
)
->
str
:
if
b_
:
return
'true'
else
:
return
'false'
class
rmsnorm_fwd_codegen
:
API_TRAITS_DEFINE
=
"""
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename XDataType_,
typename YDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
struct rmsnorm2d_fwd_traits_
{
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
};
template <typename XDataType_,
typename YDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedQuant_>
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
YDataType_,
SmoothScaleDataType_,
YScaleDataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_,
kFusedAdd_,
kFusedQuant_>;
"""
API_COMMON_HEADER
=
"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <ck_tile/ops/epilogue.hpp>
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = rmsnorm2d_fwd_args;
{F_traits_define}
template <typename Traits_>
float rmsnorm2d_fwd_(const S& s, A a)
{{
using XDataType = typename Traits_::XDataType;
using YDataType = typename Traits_::YDataType;
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
using YScaleDataType = typename Traits_::YScaleDataType;
using ComputeDataType = typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
using PipelineTraits =
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveInvRms,
Traits_::kTwoPass,
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
using PipelineProblem =
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvRmsDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
typename Traits_::Shape,
PipelineTraits>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
using Epilogue = std::conditional_t<Traits_::kFusedQuant != 0, DynamicQuantEpilogue, Default2DEpilogue>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""
API_BASE
=
"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
{F_traits_define}
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template <typename Traits_>
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
rmsnorm2d_fwd_args a,
const ck_tile::stream_config& s)
{{
float r = -1;
{F_dispatch}
return r;
}}
"""
INSTANCE_BASE
=
"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_api_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
{F_instance_def}
// clang-format on
"""
API_PER_DTYPE
=
"""
{F_if}(t.prec_i ==
\"
{F_i_type}
\"
&& t.prec_o ==
\"
{F_o_type}
\"
){{
{F_per_n_case}
}}
"""
API_PER_N_CASE
=
"""
{F_if} {F_N_COND} {{
{F_inner_dispatch}
}}
"""
API_INNER_CASE
=
"""
{F_if} {F_VEC_COND}
r={F_instance_func}(s, a);
"""
def
__init__
(
self
,
working_path
,
kernel_filter
):
self
.
working_path
=
working_path
self
.
kernel_filter
=
kernel_filter
class
k_fuesd_add_enum
(
IntEnum
):
F_NO_ADD
=
0
F_PRE_ADD
=
1
F_PRE_ADD_STORE_RESIDUAL
=
2
class
k_fused_sweep_enum
(
IntEnum
):
F_NO_SWEEP
=
0
F_RENORM
=
1
F_DYNAMIC_QUANT
=
2
@
dataclass
class
k_traits
:
F_kPadN
:
bool
F_kSaveMeanInvStd
:
bool
F_kTwoPass
:
bool
F_kFusedAdd
:
Any
F_kFusedQuant
:
Any
@
dataclass
class
k_shape
:
F_BlockTile
:
List
[
int
]
F_WarpPerBlock
:
List
[
int
]
F_WarpTile
:
List
[
int
]
F_Vector_
:
List
[
int
]
@
property
def
F_BlockSize
(
self
)
->
int
:
return
functools
.
reduce
(
lambda
a
,
b
:
a
*
b
,
self
.
F_WarpTile
)
@
dataclass
class
k_problem
:
F_XDataType
:
str
F_GammaDataType
:
str
F_ComputeDataType
:
str
F_YDataType
:
str
F_InvRmsDataType
:
str
F_BlockShape
:
str
F_Traits
:
Any
#k_traits
@
dataclass
class
k_pipeline_one_pass
:
F_Problem
:
Any
#k_problem
@
dataclass
class
k_pipeline_two_pass
:
F_Problem
:
Any
#k_problem
@
dataclass
class
default_2d_epilogue_problem
:
F_AccDataType
:
str
F_ODataType
:
str
F_kPadM
:
bool
F_kPadN
:
bool
@
dataclass
class
default_2d_epilogue
:
F_problem
:
Any
@
dataclass
class
k_kernel
:
F_pipeline
:
Any
F_epilogue
:
Any
@
dataclass
class
h_traits
:
F_XDataType
:
str
F_YDataType
:
str
F_SmoothScaleDataType
:
str
F_YScaleDataType
:
str
F_Repeat_M
:
int
F_Repeat_N
:
int
F_ThreadPerBlock_M
:
int
F_ThreadPerBlock_N
:
int
F_Vector_N
:
int
F_kPadN
:
bool
F_kSaveInvRms
:
bool
F_kTwoPass
:
bool
F_kFusedAdd
:
int
F_kFusedQuant
:
int
@
property
def
trait_name
(
self
)
->
str
:
t_
=
f
'
{
DATA_TYPE_MAP
[
self
.
F_XDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_YDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_SmoothScaleDataType
]
}
,
{
DATA_TYPE_MAP
[
self
.
F_YScaleDataType
]
}
,
{
self
.
F_Repeat_M
:
2
}
,
{
self
.
F_Repeat_N
:
2
}
,
{
self
.
F_ThreadPerBlock_M
:
2
}
,
{
self
.
F_ThreadPerBlock_N
:
4
}
'
t_
+=
f
',
{
self
.
F_Vector_N
:
2
}
,
{
BOOL_MAP
(
self
.
F_kPadN
):
5
}
,
{
BOOL_MAP
(
self
.
F_kSaveInvRms
):
5
}
'
t_
+=
f
',
{
BOOL_MAP
(
self
.
F_kTwoPass
):
5
}
,
{
self
.
F_kFusedAdd
:
4
}
,
{
self
.
F_kFusedQuant
:
4
}
'
return
t_
# string when calling this kernel
@
property
def
call_name
(
self
)
->
str
:
return
f
'rmsnorm2d_fwd_<traits_<
{
self
.
trait_name
}
>>'
# string when define this kernel
@
property
def
def_name
(
self
)
->
str
:
return
f
'template float rmsnorm2d_fwd_<traits_<
{
self
.
trait_name
}
>>(const S&, A);'
# this class hold kernel under same source file
@
dataclass
class
h_instance
:
F_DataTypePair
:
str
F_N
:
str
F_add
:
int
F_sweep
:
int
instance_list
:
List
[
Any
]
# List[h_traits]
@
property
def
name
(
self
)
->
str
:
prec_i
,
prec_o
=
self
.
F_DataTypePair
.
split
(
','
)
dtype_str
=
f
'
{
prec_i
}
'
if
prec_i
==
prec_o
else
f
'
{
prec_i
}
_
{
prec_o
}
'
nnn
=
f
'rmsnorm2d_fwd_
{
dtype_str
}
_n
{
self
.
F_N
}
'
if
self
.
F_add
!=
0
:
nnn
=
nnn
+
'_'
+
FUSED_ADD_ENUM_STR_MAP
[
self
.
F_add
]
if
self
.
F_sweep
!=
0
:
nnn
=
nnn
+
'_'
+
FUSED_FUSED_SWEEP_STR_MAP
[
self
.
F_sweep
]
return
nnn
@
property
def
instance_name
(
self
)
->
str
:
return
self
.
name
@
property
def
content
(
self
)
->
str
:
instance_defs
=
''
for
ins
in
self
.
instance_list
:
instance_defs
+=
ins
.
def_name
+
'
\n
'
return
rmsnorm_fwd_codegen
.
INSTANCE_BASE
.
format
(
F_instance_def
=
instance_defs
)
@
property
def
name_api
(
self
)
->
str
:
return
'rmsnorm2d_fwd_api'
@
property
def
name_common_header
(
self
)
->
str
:
return
'rmsnorm2d_fwd_api_common'
@
property
def
content_api
(
self
)
->
str
:
# 1 sort based on dtype
t_dtype_dict
=
dict
()
blobs
=
self
.
get_blobs
()
for
blob
in
blobs
:
if
blob
.
F_DataTypePair
not
in
t_dtype_dict
:
t_dtype_dict
[
blob
.
F_DataTypePair
]
=
{}
if
blob
.
F_N
not
in
t_dtype_dict
[
blob
.
F_DataTypePair
]:
t_dtype_dict
[
blob
.
F_DataTypePair
][
blob
.
F_N
]
=
[]
t_dtype_dict
[
blob
.
F_DataTypePair
][
blob
.
F_N
].
append
(
blob
)
d_str
=
''
for
i_d
,
dtype_
in
enumerate
(
t_dtype_dict
):
blob_per_t
=
t_dtype_dict
[
dtype_
]
n_str
=
''
for
i_n
,
n_
in
enumerate
(
blob_per_t
):
blob_per_n
=
blob_per_t
[
n_
]
inner_str
=
""
for
i_b
,
b_
in
enumerate
(
blob_per_n
):
# generate single kernel instance file
#vec_str = ""
for
i_ins
,
ins
in
enumerate
(
b_
.
instance_list
):
idx_in_n
=
i_b
*
len
(
b_
.
instance_list
)
+
i_ins
len_in_n
=
len
(
blob_per_n
)
*
len
(
b_
.
instance_list
)
# _if = 'if' if i_ins == 0 else 'else if'
if
ins
.
F_kFusedQuant
==
0
:
_sweep_cond
=
't.fused_quant == {f_fused_sweep}'
.
format
(
f_fused_sweep
=
ins
.
F_kFusedQuant
)
elif
ins
.
F_kFusedQuant
==
1
:
_sweep_cond
=
't.fused_quant == {f_fused_sweep} && (t.prec_sm ==
\"
{f_sx_type}
\"
&& t.prec_sy ==
\"
{f_sy_type}
\"
)'
.
format
(
f_fused_sweep
=
ins
.
F_kFusedQuant
,
f_sx_type
=
ins
.
F_SmoothScaleDataType
,
f_sy_type
=
ins
.
F_YScaleDataType
)
elif
ins
.
F_kFusedQuant
==
2
:
_sweep_cond
=
't.fused_quant == {f_fused_sweep} && (t.prec_sy ==
\"
{f_sy_type}
\"
)'
.
format
(
f_fused_sweep
=
ins
.
F_kFusedQuant
,
f_sy_type
=
ins
.
F_YScaleDataType
)
_cond
=
'((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'
.
format
(
f_vec_n
=
ins
.
F_Vector_N
,
f_fused_add
=
ins
.
F_kFusedAdd
,
f_sweep_cond
=
_sweep_cond
)
inner_str
+=
self
.
API_INNER_CASE
.
format
(
F_if
=
get_if_str
(
idx_in_n
,
len_in_n
,
False
),
F_VEC_COND
=
_cond
,
F_instance_func
=
ins
.
call_name
)
#inner_str = inner_str + vec_str
n_cnd
=
f
'(a.n <=
{
n_
}
)'
if
(
i_n
<
len
(
blob_per_t
)
-
1
)
else
''
n_str
+=
self
.
API_PER_N_CASE
.
format
(
F_if
=
get_if_str
(
i_n
,
len
(
blob_per_t
)),
F_N_COND
=
n_cnd
,
F_inner_dispatch
=
inner_str
)
prec_i
,
prec_o
=
dtype_
.
split
(
','
)
d_str
+=
self
.
API_PER_DTYPE
.
format
(
F_if
=
get_if_str
(
i_d
,
len
(
t_dtype_dict
),
False
),
F_i_type
=
prec_i
,
F_o_type
=
prec_o
,
F_per_n_case
=
n_str
)
api_base
=
self
.
API_BASE
.
format
(
F_traits_define
=
self
.
API_TRAITS_DEFINE
,
F_dispatch
=
d_str
)
return
api_base
@
property
def
content_common_header
(
self
)
->
str
:
return
self
.
API_COMMON_HEADER
.
format
(
F_traits_define
=
self
.
API_TRAITS_DEFINE
)
def
get_blobs
(
self
):
h_traits
=
rmsnorm_fwd_codegen
.
h_traits
h_instance
=
rmsnorm_fwd_codegen
.
h_instance
dynamic_quant_out_dtype
=
[
'int8'
,
'fp8'
]
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list
=
[(
'fp32,fp32'
)]
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
(
'fp16,int8'
),
(
'bf16,int8'
),
(
'fp16,fp8'
),
(
'bf16,fp8'
)]
# NOTE: only fused-dynamic-quant use int8 out
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list
=
[
0
,
1
]
fused_sweep_list
=
[
0
,
1
,
2
]
# NOTE: only single pass can use fused (smooth) dynamic quant
# rm rn tm tn vn pd mv 2p add sweep
h_trait_dict
=
{
'64'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
8
,
8
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'128'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'256'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'512'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'768'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
12
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'1024'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
2
,
128
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
2
,
128
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
2
,
128
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
'1536'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
4
,
64
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
2
,
128
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
'2048'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
'3072'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
128
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'4096'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'6144'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
512
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'8192'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
512
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'big'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
True
,
0
,
0
)]}
total_blob
=
list
()
for
hs_key
in
h_trait_dict
:
hs
=
h_trait_dict
[
hs_key
]
current_n
=
hs
[
0
].
F_Repeat_N
*
hs
[
0
].
F_ThreadPerBlock_N
*
hs
[
0
].
F_Vector_N
for
dtype
,
scale_type
,
fused_add
,
fused_quant
in
itertools
.
product
(
dtype_list
,
scale_list
,
fused_add_list
,
fused_sweep_list
):
prec_i
,
prec_o
=
dtype
.
split
(
','
)
scale_sm
,
scale_y
=
scale_type
.
split
(
','
)
if
prec_o
in
dynamic_quant_out_dtype
and
fused_quant
!=
1
and
fused_quant
!=
2
:
continue
# skip non dynamic quant case
if
(
fused_quant
==
1
or
fused_quant
==
2
)
and
hs_key
==
'big'
:
continue
current_hs
=
list
()
for
chs_
in
hs
:
h_
=
copy
.
copy
(
chs_
)
# copy the base instance out
h_
.
F_XDataType
=
prec_i
h_
.
F_YDataType
=
prec_o
h_
.
F_SmoothScaleDataType
=
scale_sm
h_
.
F_YScaleDataType
=
scale_y
h_
.
F_kFusedAdd
=
fused_add
h_
.
F_kFusedQuant
=
fused_quant
current_hs
.
append
(
h_
)
# + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str
=
'big'
if
hs_key
==
'big'
else
current_n
total_blob
.
append
(
h_instance
(
dtype
,
current_n_str
,
fused_add
,
fused_quant
,
current_hs
))
return
total_blob
def
list_blobs
(
self
)
->
None
:
w_p
=
Path
(
self
.
working_path
)
list_p
=
w_p
/
'rmsnorm2d_fwd_blobs.txt'
blobs
=
self
.
get_blobs
()
with
list_p
.
open
(
'w'
)
as
list_f
:
# api related file
list_f
.
write
(
str
(
w_p
/
(
self
.
name_api
+
".cpp"
))
+
"
\n
"
)
list_f
.
write
(
str
(
w_p
/
(
self
.
name_common_header
+
".hpp"
))
+
"
\n
"
)
# kernel instance file
for
b
in
blobs
:
list_f
.
write
(
str
(
w_p
/
(
b
.
name
+
".cpp"
))
+
"
\n
"
)
def
gen_blobs
(
self
)
->
None
:
w_p
=
Path
(
self
.
working_path
)
(
w_p
/
(
self
.
name_api
+
".cpp"
)).
write_text
(
self
.
content_api
)
(
w_p
/
(
self
.
name_common_header
+
".hpp"
)).
write_text
(
self
.
content_common_header
)
blobs
=
self
.
get_blobs
()
for
b
in
blobs
:
(
w_p
/
(
b
.
name
+
".cpp"
)).
write_text
(
b
.
content
)
def
list_blobs
(
args
):
api_list
=
args
.
api
.
split
(
','
)
for
api
in
api_list
:
if
api
==
'fwd'
:
rmsnorm_fwd_codegen
(
args
.
working_path
,
args
.
filter
).
list_blobs
()
def
gen_blobs
(
args
):
api_list
=
args
.
api
.
split
(
','
)
for
api
in
api_list
:
if
api
==
'fwd'
:
rmsnorm_fwd_codegen
(
args
.
working_path
,
args
.
filter
).
gen_blobs
()
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
prog
=
"generate"
,
description
=
"gen API for CK rmsnorm kernel"
,
)
parser
.
add_argument
(
"-a"
,
"--api"
,
default
=
'fwd[all]'
,
required
=
False
,
help
=
"supply API(s) to generate (default: fwd). separated by comma."
)
# the directory for list_blobs/gen_blobs to write files into
parser
.
add_argument
(
"-w"
,
"--working_path"
,
default
=
"./"
,
required
=
False
,
help
=
"the path where all the blobs are going to be generated"
)
# this script have 2 modes
# 1) list_blobs mode, will generate a txt file with all the files going to be generated.
# this is useful in build system like cmake to construct source code dependency, by
# reading the content out of this file
# 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework
# like FA, only need to use this mode
parser
.
add_argument
(
"-l"
,
"--list_blobs"
,
action
=
'store_true'
,
help
=
"list all the kernels to a file, "
)
parser
.
add_argument
(
"-g"
,
"--gen_blobs"
,
action
=
'store_true'
,
help
=
"generate all kernels into different tile"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser
.
add_argument
(
"-f"
,
"--filter"
,
required
=
False
,
help
=
"filter out kernels that need to generate, using fnmatch module"
)
parser
.
add_argument
(
"-t"
,
"--traits"
,
default
=
"all"
,
required
=
False
,
help
=
"enable/disable some feature. default generate all"
)
parser
.
add_argument
(
"-r"
,
"--receipt"
,
default
=
0
,
required
=
False
,
help
=
"codegen receipt."
)
args
=
parser
.
parse_args
()
# print(f'{args.list_blobs}-{args.gen_blobs}')
if
(
args
.
gen_blobs
and
args
.
list_blobs
)
or
((
not
args
.
gen_blobs
)
and
(
not
args
.
list_blobs
)):
print
(
'gen_blobs/list_blobs must specify only one option'
)
sys
.
exit
()
p
=
Path
(
args
.
working_path
)
if
not
p
.
exists
():
p
.
mkdir
()
if
args
.
list_blobs
:
list_blobs
(
args
)
else
:
gen_blobs
(
args
)
\ No newline at end of file
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
deleted
100644 → 0
View file @
af30d6b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
>
using
trait_
=
rmsnorm2d_fwd_traits_
<
DataType_
,
Repeat_M_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
,
kSaveInvRms_
,
kTwoPass_
>
;
template
<
typename
data_type
>
float
rmsnorm2d_fwd_b16_
(
rmsnorm2d_fwd_traits
/*t*/
,
rmsnorm2d_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
float
r
=
-
1
;
// clang-format off
// rm rn tm tn vn pd rms 2p
if
(
a
.
n
<=
64
)
{
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
128
)
{
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
256
)
{
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
512
)
{
if
(
a
.
n
%
8
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
8
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
768
)
{
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
1024
)
{
if
(
a
.
n
%
8
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
2
,
128
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
2
,
128
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
2
,
128
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
1536
)
{
if
(
a
.
n
%
8
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
2
,
128
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
2048
)
{
if
(
a
.
n
%
8
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
8
,
1
,
256
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
3072
)
{
if
(
a
.
n
%
8
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
128
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
4096
)
{
if
(
a
.
n
%
8
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
>
4096
)
{
if
(
a
.
n
%
8
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>>
(
s
,
a
);
else
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>>
(
s
,
a
);
}
return
r
;
// clang-format on
}
float
rmsnorm2d_fwd
(
rmsnorm2d_fwd_traits
t
,
rmsnorm2d_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
return
rmsnorm2d_fwd_b16_
<
ck_tile
::
fp16_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
data_type
.
compare
(
"bf16"
)
==
0
)
{
return
rmsnorm2d_fwd_b16_
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
}
else
throw
std
::
runtime_error
(
"Without supported instances!"
);
}
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
deleted
100644 → 0
View file @
af30d6b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
2
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
2
,
128
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
2
,
128
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
deleted
100644 → 0
View file @
af30d6b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
2
,
128
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
6
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
deleted
100644 → 0
View file @
af30d6b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
rmsnorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
Prev
1
2
3
4
5
6
7
8
9
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment