Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dec32dc6
"vscode:/vscode.git/clone" did not exist on "e6617e3273985bcaad82103f86d72ed0bdbd3f54"
Commit
dec32dc6
authored
Jan 31, 2025
by
ThomasNing
Browse files
Finish the feature and merge with develop on the computeV2
parents
71352c44
c5fff071
Changes
215
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
896 additions
and
445 deletions
+896
-445
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+3
-2
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+78
-89
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+73
-5
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
+28
-5
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
+31
-9
example/ck_tile/10_rmsnorm2d/generate.py
example/ck_tile/10_rmsnorm2d/generate.py
+683
-0
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
+0
-146
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
+0
-22
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
+0
-14
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
...norm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
..._rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
+0
-12
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
+0
-22
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
+0
-13
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
...rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
+0
-14
No files found.
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
dec32dc6
...
@@ -31,7 +31,7 @@ constexpr bool isDoubleSmemBuffer = false;
...
@@ -31,7 +31,7 @@ constexpr bool isDoubleSmemBuffer = false;
constexpr
bool
isDoubleSmemBuffer
=
false
;
constexpr
bool
isDoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV
3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV
4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
constexpr
bool
isDoubleSmemBuffer
=
true
;
constexpr
bool
isDoubleSmemBuffer
=
true
;
#else
#else
...
@@ -97,7 +97,8 @@ auto create_args(int argc, char* argv[])
...
@@ -97,7 +97,8 @@ auto create_args(int argc, char* argv[])
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"split_k"
,
"1"
,
"splitK value"
);
.
insert
(
"split_k"
,
"1"
,
"splitK value"
)
.
insert
(
"init"
,
"0"
,
"0:random, 1:linear, 2:constant(1)"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
template
<
typename
Layout
>
static
constexpr
inline
auto
is_row_major
(
Layout
layout_
)
{
return
ck_tile
::
bool_constant
<
std
::
is_same_v
<
ck_tile
::
remove_cvref_t
<
decltype
(
layout_
)
>
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>>
{};
}
/// Derives the relative/absolute tolerances used when comparing GEMM results.
///
/// @param K                      Reduction (K) dimension of the GEMM.
/// @param kbatch                 Split-K factor; K is divided (rounding up) by this value.
/// @param max_accumulated_value  Largest value found in the reference output.
/// @return `ck_tile::make_tuple(rtol, atol)` - element 0 is the relative threshold,
///         element 1 the absolute threshold.
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // Pick whichever input type has the smaller sizeof (BDataType when they tie).
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;

    // Number of K elements handled by one split (ceil(K / kbatch)).
    const auto k_per_split = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds for the accumulation performed inside a single split.
    const auto rtol_gemm =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_split);
    const auto atol_gemm =
        ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
            max_accumulated_value / kbatch, k_per_split);

    // Thresholds for the error introduced by accumulating the kbatch partial results.
    const auto rtol_reduce =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_reduce =
        ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(max_accumulated_value,
                                                                         kbatch);

    // The looser (larger) bound of each pair wins.
    return ck_tile::make_tuple(std::max(rtol_gemm, rtol_reduce),
                               std::max(atol_gemm, atol_reduce));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
...
@@ -67,53 +94,32 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -67,53 +94,32 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
ck_tile
::
index_t
kbatch
=
arg_parser
.
get_int
(
"split_k"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
n_repeat
=
arg_parser
.
get_int
(
"repeat"
);
ck_tile
::
index_t
init_method
=
arg_parser
.
get_int
(
"init"
);
using
namespace
ck_tile
::
literals
;
stride_A
=
ck_tile
::
get_default_stride
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
));
stride_B
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
));
auto
f_host_tensor_descriptor
=
stride_C
=
ck_tile
::
get_default_stride
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}));
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1_
uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
1_
uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
stride_A
=
f_get_default_stride
(
M
,
K
,
stride_A
,
a_layout
);
stride_B
=
f_get_default_stride
(
K
,
N
,
stride_B
,
b_layout
);
stride_C
=
f_get_default_stride
(
M
,
N
,
stride_C
,
CLayout
{});
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
stride_A
,
a_layout
));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
stride_B
,
b_layout
));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
// TODO: add different init types
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
host_tensor_descriptor
(
M
,
K
,
stride_A
,
is_row_major
(
a_layout
)));
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
ck_tile
::
host_tensor_descriptor
(
K
,
N
,
stride_B
,
is_row_major
(
b_layout
)));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{})));
if
(
init_method
==
0
)
{
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
}
else
if
(
init_method
==
1
)
{
ck_tile
::
FillMonotonicSeq
<
ADataType
>
{}(
a_m_k
);
ck_tile
::
FillMonotonicSeq
<
BDataType
>
{}(
b_k_n
);
}
else
if
(
init_method
==
2
)
{
ck_tile
::
FillConstant
<
ADataType
>
{
1.
f
}(
a_m_k
);
ck_tile
::
FillConstant
<
BDataType
>
{
1.
f
}(
b_k_n
);
}
else
{
a_m_k
.
SetZero
();
b_k_n
.
SetZero
();
}
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
...
@@ -143,20 +149,29 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -143,20 +149,29 @@ int run_gemm_example_with_layouts(int argc,
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
{
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}))
)
;
c_m_n_host_ref
.
SetZero
();
c_m_n_host_ref
.
SetZero
();
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_gpu_ref
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_C
,
CLayout
{}));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_C
,
is_row_major
(
CLayout
{}))
)
;
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_gpu_buf_ref
(
c_m_n_gpu_ref
.
get_element_space_size_in_bytes
());
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
...
@@ -196,46 +211,20 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -196,46 +211,20 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
return
pass
;
return
pass
;
}
}
/// Parses the command line and runs the GEMM example.
///
/// Only the row-major A / column-major B / row-major C configuration is launched;
/// the per-layout dispatch is kept below as commented-out code.
///
/// @return -1 when argument parsing fails, otherwise whatever
///         `run_gemm_example_with_layouts` returns.
int run_gemm_example(int argc, char* argv[])
{
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    // The layout options are still fetched from the parser even though the
    // dispatch that consumes them is currently disabled.
    std::string a_layout = arg_parser.get_str("a_layout");
    std::string b_layout = arg_parser.get_str("b_layout");

    // if(a_layout == "R" && b_layout == "R")
    // {
    //     return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
    // }
    // else if(a_layout == "R" && b_layout == "C")
    // {
    return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
    // }
    // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
    // work.
    // else if(a_layout == "C" && b_layout == "C")
    // {
    //     return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
    // }
    // else if(a_layout == "C" && b_layout == "R")
    // {
    //     return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
    // }
    // else
    // {
    //     throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    // }
}
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
dec32dc6
...
@@ -28,8 +28,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -28,8 +28,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#endif
#
el
if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
...
@@ -63,6 +63,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -63,6 +63,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
// ===============================================
// ===============================================
...
@@ -71,14 +73,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -71,14 +73,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile
2D
Partitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
isDoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
isDoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
isDoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
...
@@ -101,7 +110,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -101,7 +110,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
GemmShape
,
GemmShape
,
Traits
,
GemmUniversal
Traits
,
scheduler
,
scheduler
,
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>
;
tail_number_v
>
;
...
@@ -133,6 +142,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -133,6 +142,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if
(
has_hot_loop
)
if
(
has_hot_loop
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
}
else
{
std
::
ostringstream
err
;
err
<<
"For compute pipeline tail number should always be Full, but have
\"
"
<<
tail_num
<<
"
\"
which is not supported! PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
// Tail pipeline One to Seven
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
{
{
...
@@ -193,6 +217,16 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -193,6 +217,16 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
}
}
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V2)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
2
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Two
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
}
#endif
}
}
else
else
{
{
...
@@ -217,4 +251,38 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -217,4 +251,38 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
#include "run_gemm_example.inc"
/// Parses the command line and dispatches the GEMM example on the A/B layouts.
///
/// The "a_layout" / "b_layout" options must each be "R" (row-major) or "C"
/// (column-major); the C tensor is always passed as row-major.
///
/// @return -1 when argument parsing fails, otherwise whatever
///         `run_gemm_example_with_layouts` returns.
/// @throws std::runtime_error for any layout combination other than RR, RC, CC, CR.
int run_gemm_example(int argc, char* argv[])
{
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    const std::string a_layout = arg_parser.get_str("a_layout");
    const std::string b_layout = arg_parser.get_str("b_layout");

    const bool a_is_row = (a_layout == "R");
    const bool a_is_col = (a_layout == "C");
    const bool b_is_row = (b_layout == "R");
    const bool b_is_col = (b_layout == "C");

    if(a_is_row && b_is_row)
    {
        return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
    }
    if(a_is_row && b_is_col)
    {
        return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
    }
    if(a_is_col && b_is_col)
    {
        return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
    }
    if(a_is_col && b_is_row)
    {
        return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
    }

    // Anything else (unknown letter on either side) is rejected.
    throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
// Program entry point: forwards the command line to run_gemm_example and
// logically negates its result to form the process exit status, so any
// non-zero result becomes exit code 0 and a zero result becomes exit code 1.
// NOTE(review): run_gemm_example returns -1 when argument parsing fails, which
// this negation also maps to exit code 0 - confirm that is intended.
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
// Program entry point: forwards the command line to run_gemm_example and
// logically negates its result to form the process exit status, so any
// non-zero result becomes exit code 0 and a zero result becomes exit code 1.
// NOTE(review): run_gemm_example returns -1 when argument parsing fails, which
// this negation also maps to exit code 0 - confirm that is intended.
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
View file @
dec32dc6
set
(
RMSNORM2D_FWD_KNOWN_APIS
"fwd;bwd"
)
set
(
RMSNORM2D_FWD_ENABLE_APIS
"fwd"
CACHE STRING
"semicolon-separated list of APIs to generate (
${
RMSNORM2D_FWD_KNOWN_APIS
}
) & link, or
\"
all
\"
."
)
if
(
RMSNORM2D_FWD_ENABLE_APIS STREQUAL
"all"
)
set
(
RMSNORM2D_FWD_ENABLE_APIS
${
RMSNORM2D_FWD_KNOWN_APIS
}
)
endif
()
# generate a list of kernels, but do not actually emit files at config stage
execute_process
(
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
${
RMSNORM2D_FWD_ENABLE_APIS
}
--working_path
${
CMAKE_CURRENT_BINARY_DIR
}
--list_blobs
RESULT_VARIABLE ret
)
if
(
ret AND NOT ret EQUAL 0
)
message
(
FATAL_ERROR
"Fail to generate kernels via Python.
${
ret
}
"
)
endif
()
file
(
STRINGS
${
CMAKE_CURRENT_BINARY_DIR
}
/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS
)
add_custom_command
(
OUTPUT
${
RMSNORM2D_FWD_GEN_BLOBS
}
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
${
RMSNORM2D_FWD_ENABLE_APIS
}
--working_path
${
CMAKE_CURRENT_BINARY_DIR
}
--gen_blobs
)
set
(
TILE_RMSNORM2D_FWD
"tile_rmsnorm2d_fwd"
)
set
(
TILE_RMSNORM2D_FWD
"tile_rmsnorm2d_fwd"
)
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message
(
"adding
${
TILE_RMSNORM2D_FWD
}
"
)
message
(
"adding
${
TILE_RMSNORM2D_FWD
}
"
)
file
(
GLOB INSTANCE_SRCS instances/*.cpp
)
add_executable
(
${
TILE_RMSNORM2D_FWD
}
EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp
)
add_executable
(
${
TILE_RMSNORM2D_FWD
}
EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp
)
target_include_directories
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_include_directories
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
INSTANCE_SRC
S
}
)
target_sources
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
RMSNORM2D_FWD_GEN_BLOB
S
}
)
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
list
(
APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
--offload-compress
)
target_compile_options
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
}
)
target_compile_options
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
}
)
...
...
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
View file @
dec32dc6
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <cstring>
#include <cstring>
...
@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert
(
stride
>=
n
);
assert
(
stride
>=
n
);
using
XDataType
=
DataType
;
using
XDataType
=
DataType
;
using
YDataType
=
DataType
;
using
YDataType
=
DataType
;
using
GammaDataType
=
DataType
;
using
GammaDataType
=
DataType
;
using
InvRmsDataType
=
ck_tile
::
null_type
;
using
InvRmsDataType
=
ck_tile
::
null_type
;
using
SmoothScaleDataType
=
ck_tile
::
null_type
;
using
YScaleDataType
=
ck_tile
::
null_type
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
...
@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
BlockTile
=
ck_tile
::
sequence
<
2
,
128
>
;
using
BlockTile
=
ck_tile
::
sequence
<
2
,
128
>
;
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
PipelineTraits
=
ck_tile
::
Rmsnorm2dFwdTraits
<
true
,
// kPadN
false
,
// kSaveInvRms
kTwoPass
,
ck_tile
::
Rmsnorm2dFusedAddEnum
::
NO_ADD
,
// fuse add
ck_tile
::
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
>
;
// fuse quant
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Problem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
XDataType
,
using
Problem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
XDataType
,
GammaDataType
,
GammaDataType
,
ComputeDataType
,
ComputeDataType
,
YDataType
,
YDataType
,
InvRmsDataType
,
InvRmsDataType
,
SmoothScaleDataType
,
YScaleDataType
,
Shape
,
Shape
,
true
,
// kPadN
PipelineTraits
>
;
false
,
// kSaveInvRms
kTwoPass
>
;
using
OnePassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineOnePass
<
Problem
>
;
using
OnePassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineOnePass
<
Problem
>
;
using
TwoPassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineTwoPass
<
Problem
>
;
using
TwoPassPipeline
=
ck_tile
::
Rmsnorm2dFwdPipelineTwoPass
<
Problem
>
;
using
Pipeline
=
std
::
conditional_t
<
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Pipeline
=
std
::
conditional_t
<
kTwoPass
,
TwoPassPipeline
,
OnePassPipeline
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
>
;
using
Default2DEpilogueProblem
=
ck_tile
::
Default2DEpilogueProblem
<
ComputeDataType
,
YDataType
,
false
,
PipelineTraits
::
kPadN
,
false
>
;
using
Default2DEpilogue
=
ck_tile
::
Default2DEpilogue
<
Default2DEpilogueProblem
>
;
using
Kernel
=
ck_tile
::
Rmsnorm2dFwd
<
Pipeline
,
Default2DEpilogue
>
;
ck_tile
::
Rmsnorm2dFwdHostArgs
args
{
x_buf
.
GetDeviceBuffer
(),
ck_tile
::
Rmsnorm2dFwdHostArgs
args
{
x_buf
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
gamma_buf
.
GetDeviceBuffer
(),
gamma_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
nullptr
,
nullptr
,
epsilon
,
epsilon
,
m
,
m
,
n
,
n
,
stride
,
stride
,
stride
,
stride
};
stride
};
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
...
...
example/ck_tile/10_rmsnorm2d/generate.py
0 → 100644
View file @
dec32dc6
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import
argparse
from
enum
import
IntEnum
from
pathlib
import
Path
import
sys
from
typing
import
List
,
Optional
,
Any
import
functools
import
itertools
import
copy
from
dataclasses
import
dataclass
def get_if_str(idx, total, lase_else = True):
    """Return the branch keyword for position ``idx`` in a chain of ``total`` branches.

    * ``idx == 0``            -> ``'if'``
    * ``idx < total - 1``     -> ``'else if'``
    * otherwise (last branch) -> ``'else'`` when ``lase_else`` is truthy,
      ``'else if'`` when it is not.
    """
    if idx == 0:
        return 'if'

    # Every non-first branch is an 'else if' unless it is the final one
    # and the caller asked for a closing plain 'else'.
    is_final_branch = not (idx < total - 1)
    if is_final_branch and lase_else:
        return 'else'
    return 'else if'
# Short names for the fused-add modes, in list order.
FUSED_ADD_ENUM_STR_MAP = [
    'no',
    'pras',     # pre-norm
    'pra',      # post-norm
]

# Short names for the fused-quant (sweep) modes, in list order.
FUSED_FUSED_SWEEP_STR_MAP = [
    'no',
    'sdquant',  # smooth dynamic quant
    'dquant',   # dynamic quant (without sm_scale)
]

# Short data-type name -> C++ type spelling.
DATA_TYPE_MAP = {
    'fp32': 'float',
    'fp16': 'ck_tile::fp16_t',
    'bf16': 'ck_tile::bf16_t',
    'int8': 'ck_tile::int8_t',
    'fp8':  'ck_tile::fp8_t',
}
def BOOL_MAP(b_) -> str:
    """Map a truthy/falsy Python value to the literal ``'true'`` / ``'false'``."""
    return 'true' if b_ else 'false'
class
rmsnorm_fwd_codegen
:
API_TRAITS_DEFINE
=
"""
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename XDataType_,
typename YDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
struct rmsnorm2d_fwd_traits_
{
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
};
template <typename XDataType_,
typename YDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedQuant_>
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
YDataType_,
SmoothScaleDataType_,
YScaleDataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_,
kFusedAdd_,
kFusedQuant_>;
"""
# Template of the shared C++ header emitted by content_common_header (written by gen_blobs as
# name_common_header + ".hpp"). {F_traits_define} receives API_TRAITS_DEFINE; doubled braces
# are str.format escapes for literal C++ braces. The header defines rmsnorm2d_fwd_<Traits_>,
# which assembles the pipeline/epilogue/kernel types from Traits_ and launches the kernel.
API_COMMON_HEADER
=
"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <ck_tile/ops/epilogue.hpp>
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = rmsnorm2d_fwd_args;
{F_traits_define}
template <typename Traits_>
float rmsnorm2d_fwd_(const S& s, A a)
{{
using XDataType = typename Traits_::XDataType;
using YDataType = typename Traits_::YDataType;
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
using YScaleDataType = typename Traits_::YScaleDataType;
using ComputeDataType = typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
using PipelineTraits =
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveInvRms,
Traits_::kTwoPass,
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
using PipelineProblem =
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvRmsDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
typename Traits_::Shape,
PipelineTraits>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
using Epilogue = std::conditional_t<Traits_::kFusedQuant != 0, DynamicQuantEpilogue, Default2DEpilogue>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""
# Template of the generated API translation unit (content_api, written by gen_blobs as
# name_api + ".cpp"). {F_traits_define} receives API_TRAITS_DEFINE and {F_dispatch} the
# per-dtype dispatch chain; rmsnorm2d_fwd returns -1 when no dispatch arm assigns r.
API_BASE
=
"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
{F_traits_define}
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template <typename Traits_>
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
rmsnorm2d_fwd_args a,
const ck_tile::stream_config& s)
{{
float r = -1;
{F_dispatch}
return r;
}}
"""
# Template of one kernel-instance source file (h_instance.content): it includes the common
# header and {F_instance_def} receives the explicit instantiation lines, one per h_traits.
INSTANCE_BASE
=
"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_api_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
{F_instance_def}
// clang-format on
"""
# Template of the outermost dispatch level used by content_api: one arm per "prec_i,prec_o"
# pair. {F_if} is the if/else-if keyword, {F_i_type}/{F_o_type} the two precision names and
# {F_per_n_case} the nested per-N dispatch text.
API_PER_DTYPE
=
"""
{F_if}(t.prec_i ==
\"
{F_i_type}
\"
&& t.prec_o ==
\"
{F_o_type}
\"
){{
{F_per_n_case}
}}
"""
# Template of the middle dispatch level used by content_api: one arm per N bucket.
# {F_N_COND} is the "(a.n <= N)" test (empty for the last bucket) and {F_inner_dispatch}
# the concatenated inner arms for that bucket.
API_PER_N_CASE
=
"""
{F_if} {F_N_COND} {{
{F_inner_dispatch}
}}
"""
# Template of the innermost dispatch arm used by content_api: {F_if} is the if/else-if
# keyword, {F_VEC_COND} the runtime condition and {F_instance_func} the kernel call whose
# result is stored in r.
API_INNER_CASE
=
"""
{F_if} {F_VEC_COND}
r={F_instance_func}(s, a);
"""
def
__init__
(
self
,
working_path
,
kernel_filter
):
self
.
working_path
=
working_path
self
.
kernel_filter
=
kernel_filter
class k_fuesd_add_enum(IntEnum):
    """Integer codes for the fused-add variants (values 0, 1 and 2)."""
    F_NO_ADD = 0                  # no fused add
    F_PRE_ADD = 1                 # presumably add before normalisation - confirm against the kernel
    F_PRE_ADD_STORE_RESIDUAL = 2  # presumably pre-add that also stores the sum - confirm
class k_fused_sweep_enum(IntEnum):
    """Integer codes for the fused sweep variants (values 0, 1 and 2)."""
    F_NO_SWEEP = 0       # no fused sweep
    F_RENORM = 1
    F_DYNAMIC_QUANT = 2
@dataclass
class k_traits:
    """Plain record of the kernel trait switches (three booleans and two fusion selectors)."""
    F_kPadN: bool
    F_kSaveMeanInvStd: bool
    F_kTwoPass: bool
    F_kFusedAdd: Any
    F_kFusedQuant: Any
@dataclass
class k_shape:
    """Tile-shape record made of four integer lists."""
    F_BlockTile: List[int]
    F_WarpPerBlock: List[int]
    F_WarpTile: List[int]
    F_Vector_: List[int]

    @property
    def F_BlockSize(self) -> int:
        """Return the product of all entries of F_WarpTile."""
        # NOTE(review): the name says "BlockSize" but the product is taken over F_WarpTile - confirm intent.
        return functools.reduce(lambda acc, dim: acc * dim, self.F_WarpTile)
@dataclass
class k_problem:
    """Record of the data-type names, block-shape name and traits describing one problem."""
    F_XDataType: str
    F_GammaDataType: str
    F_ComputeDataType: str
    F_YDataType: str
    F_InvRmsDataType: str
    F_BlockShape: str
    F_Traits: Any  # k_traits
@dataclass
class k_pipeline_one_pass:
    """Wrapper holding the problem description for the one-pass pipeline."""
    F_Problem: Any  # k_problem
@dataclass
class k_pipeline_two_pass:
    """Wrapper holding the problem description for the two-pass pipeline."""
    F_Problem: Any  # k_problem
@dataclass
class default_2d_epilogue_problem:
    """Record of the accumulator/output type names and the two padding flags of an epilogue."""
    F_AccDataType: str
    F_ODataType: str
    F_kPadM: bool
    F_kPadN: bool
@dataclass
class default_2d_epilogue:
    """Wrapper holding an epilogue problem description."""
    F_problem: Any
@dataclass
class k_kernel:
    """Pairing of a pipeline description with an epilogue description."""
    F_pipeline: Any
    F_epilogue: Any
@dataclass
class h_traits:
    """All template arguments of one generated kernel instance.

    The fields mirror, in order, the parameter list of the C++ `traits_` alias; the
    properties render them as C++ source text.
    """
    F_XDataType: str
    F_YDataType: str
    F_SmoothScaleDataType: str
    F_YScaleDataType: str
    F_Repeat_M: int
    F_Repeat_N: int
    F_ThreadPerBlock_M: int
    F_ThreadPerBlock_N: int
    F_Vector_N: int
    F_kPadN: bool
    F_kSaveInvRms: bool
    F_kTwoPass: bool
    F_kFusedAdd: int
    F_kFusedQuant: int

    @property
    def trait_name(self) -> str:
        """Comma-separated template-argument list, column-aligned via fixed field widths."""
        dtypes = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}'
        tiling = f'{self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
        vec_and_flags = f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}'
        pass_and_fusion = f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
        return dtypes + ', ' + tiling + vec_and_flags + pass_and_fusion

    # string when calling this kernel
    @property
    def call_name(self) -> str:
        """C++ expression naming the instantiated kernel entry point."""
        args = self.trait_name
        return f'rmsnorm2d_fwd_<traits_<{args}>>'

    # string when define this kernel
    @property
    def def_name(self) -> str:
        """C++ explicit-instantiation statement for this kernel."""
        args = self.trait_name
        return f'template float rmsnorm2d_fwd_<traits_<{args}>>(const S&, A);'
# this class hold kernel under same source file
@dataclass
class h_instance:
    """Group of kernel instances that are emitted into one generated source file."""
    F_DataTypePair: str       # "prec_i,prec_o"
    F_N: str                  # NOTE(review): get_blobs passes an int here except for 'big'
    F_add: int
    F_sweep: int
    instance_list: List[Any]  # List[h_traits]

    @property
    def name(self) -> str:
        """File stem: rmsnorm2d_fwd_<dtype>_n<N> plus optional fused-add / sweep suffixes."""
        prec_i, prec_o = self.F_DataTypePair.split(',')
        # identical input/output precision is written once, otherwise as "<in>_<out>"
        dtype_str = prec_i if prec_i == prec_o else prec_i + '_' + prec_o
        parts = ['rmsnorm2d_fwd', dtype_str, f'n{self.F_N}']
        if self.F_add != 0:
            parts.append(FUSED_ADD_ENUM_STR_MAP[self.F_add])
        if self.F_sweep != 0:
            parts.append(FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep])
        return '_'.join(parts)

    @property
    def instance_name(self) -> str:
        """Alias of `name`."""
        return self.name

    @property
    def content(self) -> str:
        """Full text of the instance file: INSTANCE_BASE filled with one definition per line."""
        definitions = ''.join(ins.def_name + '\n' for ins in self.instance_list)
        return rmsnorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=definitions)
@
property
def
name_api
(
self
)
->
str
:
return
'rmsnorm2d_fwd_api'
@
property
def
name_common_header
(
self
)
->
str
:
return
'rmsnorm2d_fwd_api_common'
@property
def content_api(self) -> str:
    """Render the API source file: a three-level runtime dispatch over all generated kernels.

    Blobs from get_blobs() are grouped by "prec_i,prec_o" pair and then by N bucket
    (insertion order is kept, so the last bucket becomes the unconditional fallback).
    Every instance contributes one inner arm guarded by vector-size divisibility, the
    fused-add mode and the fused-quant mode (plus scale precisions when quantising).

    Returns:
        API_BASE filled with the traits definition and the dispatch chain.

    Raises:
        ValueError: an instance carries an F_kFusedQuant value other than 0, 1 or 2.
    """
    # 1 sort based on dtype, then on N
    t_dtype_dict = dict()
    for blob in self.get_blobs():
        t_dtype_dict.setdefault(blob.F_DataTypePair, {}).setdefault(blob.F_N, []).append(blob)

    d_str = ''
    for i_d, (dtype_, blob_per_t) in enumerate(t_dtype_dict.items()):
        n_str = ''
        for i_n, (n_, blob_per_n) in enumerate(blob_per_t.items()):
            inner_str = ""
            for i_b, b_ in enumerate(blob_per_n):
                for i_ins, ins in enumerate(b_.instance_list):
                    # position of this arm among all arms of the N bucket; assumes every
                    # blob of the bucket carries the same number of instances
                    idx_in_n = i_b * len(b_.instance_list) + i_ins
                    len_in_n = len(blob_per_n) * len(b_.instance_list)
                    if ins.F_kFusedQuant == 0:
                        _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(
                            f_fused_sweep=ins.F_kFusedQuant)
                    elif ins.F_kFusedQuant == 1:
                        _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
                            f_fused_sweep=ins.F_kFusedQuant,
                            f_sx_type=ins.F_SmoothScaleDataType,
                            f_sy_type=ins.F_YScaleDataType)
                    elif ins.F_kFusedQuant == 2:
                        _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
                            f_fused_sweep=ins.F_kFusedQuant,
                            f_sy_type=ins.F_YScaleDataType)
                    else:
                        # Without this branch an unknown mode would silently reuse the
                        # condition of the previous instance (or fail as an unbound local).
                        raise ValueError(
                            f'unsupported F_kFusedQuant value: {ins.F_kFusedQuant!r}')
                    _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
                        f_vec_n=ins.F_Vector_N,
                        f_fused_add=ins.F_kFusedAdd,
                        f_sweep_cond=_sweep_cond)
                    inner_str += self.API_INNER_CASE.format(
                        F_if=get_if_str(idx_in_n, len_in_n, False),
                        F_VEC_COND=_cond,
                        F_instance_func=ins.call_name)
            # the last N bucket gets no upper bound and therefore catches everything larger
            n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
            n_str += self.API_PER_N_CASE.format(
                F_if=get_if_str(i_n, len(blob_per_t)),
                F_N_COND=n_cnd,
                F_inner_dispatch=inner_str)
        prec_i, prec_o = dtype_.split(',')
        d_str += self.API_PER_DTYPE.format(
            F_if=get_if_str(i_d, len(t_dtype_dict), False),
            F_i_type=prec_i,
            F_o_type=prec_o,
            F_per_n_case=n_str)

    return self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
@
property
def
content_common_header
(
self
)
->
str
:
return
self
.
API_COMMON_HEADER
.
format
(
F_traits_define
=
self
.
API_TRAITS_DEFINE
)
def
get_blobs
(
self
):
h_traits
=
rmsnorm_fwd_codegen
.
h_traits
h_instance
=
rmsnorm_fwd_codegen
.
h_instance
dynamic_quant_out_dtype
=
[
'int8'
,
'fp8'
]
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list
=
[(
'fp32,fp32'
)]
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
(
'fp16,int8'
),
(
'bf16,int8'
),
(
'fp16,fp8'
),
(
'bf16,fp8'
)]
# NOTE: only fused-dynamic-quant use int8 out
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list
=
[
0
,
1
]
fused_sweep_list
=
[
0
,
1
,
2
]
# NOTE: only single pass can use fused (smooth) dynamic quant
# rm rn tm tn vn pd mv 2p add sweep
h_trait_dict
=
{
'64'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
8
,
8
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'128'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
16
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'256'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'512'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
4
,
64
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'768'
:
[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
4
,
64
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
4
,
64
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
12
,
4
,
64
,
1
,
True
,
False
,
False
,
0
,
0
)],
'1024'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
2
,
128
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
2
,
128
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
2
,
128
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
'1536'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
4
,
64
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
2
,
128
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
'2048'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
1
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
256
,
1
,
True
,
False
,
False
,
0
,
0
)],
'3072'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
128
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
256
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'4096'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'6144'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
512
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
3
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
6
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'8192'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
8
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
512
,
4
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
2
,
True
,
False
,
False
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
8
,
1
,
1024
,
1
,
True
,
False
,
False
,
0
,
0
)],
'big'
:[
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
256
,
8
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
256
,
4
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
2
,
1
,
1024
,
2
,
True
,
False
,
True
,
0
,
0
),
h_traits
(
'x'
,
'y'
,
'xs'
,
'ys'
,
1
,
4
,
1
,
1024
,
1
,
True
,
False
,
True
,
0
,
0
)]}
total_blob
=
list
()
for
hs_key
in
h_trait_dict
:
hs
=
h_trait_dict
[
hs_key
]
current_n
=
hs
[
0
].
F_Repeat_N
*
hs
[
0
].
F_ThreadPerBlock_N
*
hs
[
0
].
F_Vector_N
for
dtype
,
scale_type
,
fused_add
,
fused_quant
in
itertools
.
product
(
dtype_list
,
scale_list
,
fused_add_list
,
fused_sweep_list
):
prec_i
,
prec_o
=
dtype
.
split
(
','
)
scale_sm
,
scale_y
=
scale_type
.
split
(
','
)
if
prec_o
in
dynamic_quant_out_dtype
and
fused_quant
!=
1
and
fused_quant
!=
2
:
continue
# skip non dynamic quant case
if
(
fused_quant
==
1
or
fused_quant
==
2
)
and
hs_key
==
'big'
:
continue
current_hs
=
list
()
for
chs_
in
hs
:
h_
=
copy
.
copy
(
chs_
)
# copy the base instance out
h_
.
F_XDataType
=
prec_i
h_
.
F_YDataType
=
prec_o
h_
.
F_SmoothScaleDataType
=
scale_sm
h_
.
F_YScaleDataType
=
scale_y
h_
.
F_kFusedAdd
=
fused_add
h_
.
F_kFusedQuant
=
fused_quant
current_hs
.
append
(
h_
)
# + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str
=
'big'
if
hs_key
==
'big'
else
current_n
total_blob
.
append
(
h_instance
(
dtype
,
current_n_str
,
fused_add
,
fused_quant
,
current_hs
))
return
total_blob
def list_blobs(self) -> None:
    """Write 'rmsnorm2d_fwd_blobs.txt' listing every file gen_blobs would produce.

    The list file is created in the working path and holds one path per line: the API
    source, the common header, then one source file per kernel instance group.
    """
    out_dir = Path(self.working_path)
    manifest = out_dir / 'rmsnorm2d_fwd_blobs.txt'
    blobs = self.get_blobs()
    with manifest.open('w') as list_f:
        # api related file first, kernel instance files afterwards
        file_names = [self.name_api + ".cpp", self.name_common_header + ".hpp"]
        file_names.extend(b.name + ".cpp" for b in blobs)
        for file_name in file_names:
            list_f.write(str(out_dir / file_name) + "\n")
def gen_blobs(self) -> None:
    """Write the API source, the common header and every kernel instance file.

    All files are created directly inside the working path; existing files are overwritten.
    """
    out_dir = Path(self.working_path)

    def emit(stem, suffix, text):
        (out_dir / (stem + suffix)).write_text(text)

    emit(self.name_api, ".cpp", self.content_api)
    emit(self.name_common_header, ".hpp", self.content_common_header)
    for b in self.get_blobs():
        emit(b.name, ".cpp", b.content)
def list_blobs(args):
    """Run the list-blobs mode for every API named in the comma-separated args.api.

    Only the 'fwd' API is handled; other names are ignored.
    """
    for api in args.api.split(','):
        if api != 'fwd':
            continue
        codegen = rmsnorm_fwd_codegen(args.working_path, args.filter)
        codegen.list_blobs()
def gen_blobs(args):
    """Run the gen-blobs mode for every API named in the comma-separated args.api.

    Only the 'fwd' API is handled; other names are ignored.
    """
    for api in args.api.split(','):
        if api != 'fwd':
            continue
        codegen = rmsnorm_fwd_codegen(args.working_path, args.filter)
        codegen.gen_blobs()
if __name__ == "__main__":
    # Command-line entry point: parse the options, make sure the working directory exists,
    # then run exactly one of the two modes (list_blobs or gen_blobs).
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK rmsnorm kernel",
    )
    # The default must be a name the dispatchers recognise; they compare each entry against
    # 'fwd', so the former default 'fwd[all]' made a default run generate nothing.
    parser.add_argument(
        "-a",
        "--api",
        default='fwd',
        required=False,
        help="supply API(s) to generate (default: fwd). separated by comma."
    )

    # the directory for list_blobs/gen_blobs to write files into
    parser.add_argument(
        "-w",
        "--working_path",
        default="./",
        required=False,
        help="the path where all the blobs are going to be generated"
    )

    # this script have 2 modes
    # 1) list_blobs mode, will generate a txt file with all the files going to be generated.
    #    this is useful in build system like cmake to construct source code dependency, by
    #    reading the content out of this file
    # 2) gen_blobs mode, will generate the actual kernel instance and api. If in framework
    #    like FA, only need to use this mode
    parser.add_argument(
        "-l",
        "--list_blobs",
        action='store_true',
        help="list all the kernels to a file, "
    )

    parser.add_argument(
        "-g",
        "--gen_blobs",
        action='store_true',
        help="generate all kernels into different tile"
    )

    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-t",
        "--traits",
        default="all",
        required=False,
        help="enable/disable some feature. default generate all"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        required=False,
        help="codegen receipt."
    )

    args = parser.parse_args()

    # Exactly one mode must be selected. Passing the message to sys.exit prints it to
    # stderr and exits with status 1, so a calling build system sees the failure
    # (a bare sys.exit() reported success).
    if args.gen_blobs == args.list_blobs:
        sys.exit('gen_blobs/list_blobs must specify only one option')

    # parents/exist_ok: nested output directories are created as needed and a directory
    # that appears concurrently is not an error.
    Path(args.working_path).mkdir(parents=True, exist_ok=True)

    if args.list_blobs:
        list_blobs(args)
    else:
        gen_blobs(args)
\ No newline at end of file
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
// Shorthand for rmsnorm2d_fwd_traits_ taking the same parameter list, used by the
// dispatch code in this file.
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t Repeat_N_,         // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kSaveInvRms_,
          bool kTwoPass_>
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
                                     Repeat_M_,
                                     Repeat_N_,
                                     ThreadPerBlock_M_,
                                     ThreadPerBlock_N_,
                                     Vector_N_,
                                     kPadN_,
                                     kSaveInvRms_,
                                     kTwoPass_>;
// Selects a kernel instance for a 16-bit data type from the row length a.n.
// The outer test picks the size bucket; inside a bucket the widest vector size that
// divides a.n is used. Rows longer than 4096 use the instances whose last flag (2p) is true.
// The traits argument is accepted but not inspected.
template <typename data_type>
float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
                         rmsnorm2d_fwd_args a,
                         const ck_tile::stream_config& s)
{
    const auto n = a.n;
    // clang-format off
    //                                                   rm  rn  tm    tn  vn  pd    rms    2p
    if(n <= 64)
    {
        return rmsnorm2d_fwd_<trait_<data_type,           1,  1,  4,   64,  1, true, false, false>>(s, a);
    }
    if(n <= 128)
    {
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  1,  4,   64,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1,  2,  4,   64,  1, true, false, false>>(s, a);
    }
    if(n <= 256)
    {
        if(n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  1,  4,   64,  4, true, false, false>>(s, a);
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  2,  4,   64,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1,  4,  4,   64,  1, true, false, false>>(s, a);
    }
    if(n <= 512)
    {
        if(n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  1,  4,   64,  8, true, false, false>>(s, a);
        if(n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  2,  4,   64,  4, true, false, false>>(s, a);
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  4,  4,   64,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1,  8,  4,   64,  1, true, false, false>>(s, a);
    }
    if(n <= 768)
    {
        if(n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  3,  4,   64,  4, true, false, false>>(s, a);
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  6,  4,   64,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1, 12,  4,   64,  1, true, false, false>>(s, a);
    }
    if(n <= 1024)
    {
        if(n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  1,  2,  128,  8, true, false, false>>(s, a);
        if(n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  2,  2,  128,  4, true, false, false>>(s, a);
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  4,  2,  128,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1,  4,  1,  256,  1, true, false, false>>(s, a);
    }
    if(n <= 1536)
    {
        if(n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  3,  4,   64,  8, true, false, false>>(s, a);
        if(n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  3,  2,  128,  4, true, false, false>>(s, a);
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  3,  1,  256,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1,  6,  1,  256,  1, true, false, false>>(s, a);
    }
    if(n <= 2048)
    {
        if(n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  1,  1,  256,  8, true, false, false>>(s, a);
        if(n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  2,  1,  256,  4, true, false, false>>(s, a);
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  4,  1,  256,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1,  8,  1,  256,  1, true, false, false>>(s, a);
    }
    if(n <= 3072)
    {
        if(n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  3,  1,  128,  8, true, false, false>>(s, a);
        if(n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  3,  1,  256,  4, true, false, false>>(s, a);
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  6,  1,  256,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1,  3,  1, 1024,  1, true, false, false>>(s, a);
    }
    if(n <= 4096)
    {
        if(n % 8 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  2,  1,  256,  8, true, false, false>>(s, a);
        if(n % 4 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  4,  1,  256,  4, true, false, false>>(s, a);
        if(n % 2 == 0)
            return rmsnorm2d_fwd_<trait_<data_type,       1,  2,  1, 1024,  2, true, false, false>>(s, a);
        return rmsnorm2d_fwd_<trait_<data_type,           1,  4,  1, 1024,  1, true, false, false>>(s, a);
    }
    // n > 4096
    if(n % 8 == 0)
        return rmsnorm2d_fwd_<trait_<data_type,           1,  2,  1,  256,  8, true, false, true>>(s, a);
    if(n % 4 == 0)
        return rmsnorm2d_fwd_<trait_<data_type,           1,  4,  1,  256,  4, true, false, true>>(s, a);
    if(n % 2 == 0)
        return rmsnorm2d_fwd_<trait_<data_type,           1,  2,  1, 1024,  2, true, false, true>>(s, a);
    return rmsnorm2d_fwd_<trait_<data_type,               1,  4,  1, 1024,  1, true, false, true>>(s, a);
    // clang-format on
}
// Public entry point: forwards to the 16-bit dispatcher matching t.data_type
// ("fp16" or "bf16") and throws std::runtime_error for any other value.
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
                    rmsnorm2d_fwd_args a,
                    const ck_tile::stream_config& s)
{
    const bool is_fp16 = t.data_type.compare("fp16") == 0;
    if(is_fp16)
    {
        return rmsnorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
    }

    const bool is_bf16 = t.data_type.compare("bf16") == 0;
    if(is_bf16)
    {
        return rmsnorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
    }

    throw std::runtime_error("Without supported instances!");
}
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
// Explicit instantiations of the bf16 kernels listed in this file (two-pass flag false).
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  1,  2, 128,  8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  2,  2, 128,  4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  4,  2, 128,  2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  4,  1, 256,  1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
// Explicit instantiations of the bf16 kernels listed in this file (two-pass flag false).
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  4,  64,  8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  2, 128,  4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  1, 256,  2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  6,  1, 256,  1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
// Explicit instantiations of the bf16 kernels listed in this file (two-pass flag false).
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  1,  1, 256,  8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  2,  1, 256,  4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  4,  1, 256,  2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  8,  1, 256,  1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
// Explicit instantiations of the bf16 kernels listed in this file (two-pass flag false).
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  1,  4,  64,  4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  2,  4,  64,  2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  4,  4,  64,  1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
// Explicit instantiations of the bf16 kernels listed in this file (two-pass flag false).
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  1,  128,  8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  1,  256,  4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  6,  1,  256,  2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  1, 1024,  1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
// Explicit instantiations of the bf16 kernels listed in this file (two-pass flag false).
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  2,  1,  256,  8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  4,  1,  256,  4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  2,  1, 1024,  2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  4,  1, 1024,  1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of rmsnorm2d_fwd_ for ck_tile::bf16_t.
// trait_ arguments after the data type, in order: rm, rn, tm, tn, vn, pd, rms, 2p
// (abbreviations kept from the original column header; presumably repeat/thread/vector
//  sizes followed by pad / save-rms / two-pass flags - confirm against trait_).
// Every row below has rn * tn * vn == 4096 and the final (2p) flag set to true.
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1,  256, 8, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1,  256, 4, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of rmsnorm2d_fwd_ for ck_tile::bf16_t.
// trait_ arguments after the data type, in order: rm, rn, tm, tn, vn, pd, rms, 2p
// (abbreviations kept from the original column header; presumably repeat/thread/vector
//  sizes followed by pad / save-rms / two-pass flags - confirm against trait_).
// Every row below has rn * tn * vn == 512.
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of rmsnorm2d_fwd_ for ck_tile::bf16_t.
// trait_ arguments after the data type, in order: rm, rn, tm, tn, vn, pd, rms, 2p
// (abbreviations kept from the original column header; presumably repeat/thread/vector
//  sizes followed by pad / save-rms / two-pass flags - confirm against trait_).
// The rn * tn * vn products below are 64, 128 and 128 respectively.
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of rmsnorm2d_fwd_ for ck_tile::bf16_t.
// trait_ arguments after the data type, in order: rm, rn, tm, tn, vn, pd, rms, 2p
// (abbreviations kept from the original column header; presumably repeat/thread/vector
//  sizes followed by pad / save-rms / two-pass flags - confirm against trait_).
// Every row below has rn * tn * vn == 768.
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3, 4, 64, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  6, 4, 64, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of rmsnorm2d_fwd_ for ck_tile::fp16_t.
// trait_ arguments after the data type, in order: rm, rn, tm, tn, vn, pd, rms, 2p
// (abbreviations kept from the original column header; presumably repeat/thread/vector
//  sizes followed by pad / save-rms / two-pass flags - confirm against trait_).
// Every row below, including the compiled-out ones, has rn * tn * vn == 1024.

// The following instantiations are compiled out by the preprocessor (#if 0).
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif

// Active instantiations.
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of rmsnorm2d_fwd_ for ck_tile::fp16_t.
// trait_ arguments after the data type, in order: rm, rn, tm, tn, vn, pd, rms, 2p
// (abbreviations kept from the original column header; presumably repeat/thread/vector
//  sizes followed by pad / save-rms / two-pass flags - confirm against trait_).
// Every row below has rn * tn * vn == 1536.
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4,  64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp
deleted
100644 → 0
View file @
71352c44
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of rmsnorm2d_fwd_ for ck_tile::fp16_t.
// trait_ arguments after the data type, in order: rm, rn, tm, tn, vn, pd, rms, 2p
// (abbreviations kept from the original column header; presumably repeat/thread/vector
//  sizes followed by pad / save-rms / two-pass flags - confirm against trait_).
// Every row below has rn * tn * vn == 2048.
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
Prev
1
2
3
4
5
6
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment