Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
487826b3
Commit
487826b3
authored
Jan 10, 2025
by
aska-0096
Browse files
tempsave, fp8 sanity error
parent
50e10656
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
354 additions
and
170 deletions
+354
-170
example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp
...le/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp
+70
-43
example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp
...le/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp
+33
-6
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
...orm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
+107
-52
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp
+8
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp
+8
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp
+8
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp
...nstances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp
+6
-3
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp
+8
-5
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp
+8
-5
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp
...nstances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp
+8
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp
...nces/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp
+6
-3
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp
...nstances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp
+6
-3
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp
+24
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp
...nces/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp
+8
-5
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp
+8
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp
+8
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp
+8
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp
...nstances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp
+6
-3
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp
+8
-5
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp
+8
-5
No files found.
example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp
View file @
487826b3
...
...
@@ -3,7 +3,7 @@
#include <cstring>
// different threshold for different dtype
template
<
typename
DataType
>
template
<
typename
Input
DataType
>
auto
get_elimit
()
{
double
rtol
=
1e-2
;
...
...
@@ -39,6 +39,7 @@ auto create_args(int argc, char* argv[])
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
.
insert
(
"quant"
,
"int8"
,
"precision"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
...
...
@@ -46,7 +47,7 @@ auto create_args(int argc, char* argv[])
return
std
::
make_tuple
(
result
,
arg_parser
);
}
template
<
typename
DataType
,
bool
SaveX
>
template
<
typename
InputDataType
,
typename
Quantized
DataType
,
bool
SaveX
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
...
...
@@ -54,16 +55,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
stride
=
n
;
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
std
::
string
input_data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
quantized_data_type
=
arg_parser
.
get_str
(
"quant"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
assert
(
stride
>=
n
);
using
TypeConfig
=
AddRmsnormRdquantTypeConfig
<
DataType
>
;
using
TypeConfig
=
AddRmsnormRdquantTypeConfig
<
InputDataType
,
Quantized
DataType
>
;
using
ADataType
=
typename
TypeConfig
::
ADataType
;
using
BDataType
=
typename
TypeConfig
::
BDataType
;
...
...
@@ -102,10 +104,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
b_buf
.
ToDevice
(
b_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
std
::
cout
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"["
<<
input_data_type
<<
", "
<<
quantized_
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
add_rmsnorm2d_rdquant_fwd_traits
traits
{
data_type
,
SaveX
};
add_rmsnorm2d_rdquant_fwd_traits
traits
{
input_data_type
,
quantized_
data_type
,
SaveX
};
add_rmsnorm2d_rdquant_fwd_args
args
{
a_buf
.
GetDeviceBuffer
(),
b_buf
.
GetDeviceBuffer
(),
...
...
@@ -129,14 +131,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_byte
+=
sizeof
(
XDataType
)
*
m
*
n
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
flush
;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_validation
)
{
using
YDataType
=
ComputeDataType
;
using
InvRmsDataType
=
DataType
;
using
InvRmsDataType
=
Input
DataType
;
// Add
{
...
...
@@ -144,28 +146,36 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
reference_binary_elementwise
<
ADataType
,
BDataType
,
XDataType
,
ComputeDataType
>
(
a_host
,
b_host
,
x_host_ref
,
op
);
x_buf
.
FromDevice
(
x_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
XDataType
>
();
if
(
stride
==
n
)
{
pass
=
ck_tile
::
check_err
(
x_host_dev
,
x_host_ref
,
std
::
string
(
"x Error: Incorrect results!"
),
rtol
,
atol
);
}
else
if
constexpr
(
SaveX
)
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
x_buf
.
FromDevice
(
x_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
XDataType
>
();
if
(
stride
==
n
)
{
std
::
vector
<
QYDataType
>
x_host_dev_row
(
x_host_dev
.
begin
()
+
i_r
*
stride
,
x_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
x_host_ref_row
(
x_host_ref
.
begin
()
+
i_r
*
stride
,
x_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
pass
&=
ck_tile
::
check_err
(
x_host_dev_row
,
x_host_ref_row
,
std
::
string
(
"x["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
atol
);
pass
=
ck_tile
::
check_err
(
x_host_dev
,
x_host_ref
,
std
::
string
(
"x Error: Incorrect results!"
),
rtol
,
atol
);
}
else
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
std
::
vector
<
QYDataType
>
x_host_dev_row
(
x_host_dev
.
begin
()
+
i_r
*
stride
,
x_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
x_host_ref_row
(
x_host_ref
.
begin
()
+
i_r
*
stride
,
x_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
pass
&=
ck_tile
::
check_err
(
x_host_dev_row
,
x_host_ref_row
,
std
::
string
(
"x["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
atol
);
}
}
}
}
...
...
@@ -256,23 +266,40 @@ int main(int argc, char* argv[])
if
(
!
result
)
return
-
1
;
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
save_x
=
arg_parser
.
get_int
(
"save_x"
);
if
(
data_type
==
"fp16"
&&
save_x
)
const
std
::
string
input_data_type
=
arg_parser
.
get_str
(
"prec"
);
const
std
::
string
quantized_data_type
=
arg_parser
.
get_str
(
"quant"
);
int
save_x
=
arg_parser
.
get_int
(
"save_x"
);
if
(
input_data_type
==
"fp16"
&&
quantized_data_type
==
"int8"
&&
save_x
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
int8_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
input_data_type
==
"fp16"
&&
quantized_data_type
==
"int8"
&&
!
save_x
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
int8_t
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
input_data_type
==
"bf16"
&&
quantized_data_type
==
"int8"
&&
save_x
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
input_data_type
==
"bf16"
&&
quantized_data_type
==
"int8"
&&
!
save_x
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
input_data_type
==
"fp16"
&&
quantized_data_type
==
"fp8"
&&
save_x
)
{
return
run
<
ck_tile
::
half_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
data_type
==
"fp16"
&&
!
save_x
)
else
if
(
input_
data_type
==
"fp16"
&&
quantized_data_type
==
"fp8"
&&
!
save_x
)
{
return
run
<
ck_tile
::
half_t
,
false
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
data_type
==
"bf16"
&&
save_x
)
else
if
(
input_
data_type
==
"bf16"
&&
quantized_data_type
==
"fp8"
&&
save_x
)
{
return
run
<
ck_tile
::
bf16_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
data_type
==
"bf16"
&&
!
save_x
)
else
if
(
input_
data_type
==
"bf16"
&&
quantized_data_type
==
"fp8"
&&
!
save_x
)
{
return
run
<
ck_tile
::
bf16_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
return
-
3
;
...
...
example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp
View file @
487826b3
...
...
@@ -8,11 +8,11 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp"
#include <string>
template
<
typename
DataType
>
template
<
typename
InputDataType
,
typename
Quantized
DataType
>
struct
AddRmsnormRdquantTypeConfig
;
template
<
>
struct
AddRmsnormRdquantTypeConfig
<
ck_tile
::
half_t
>
struct
AddRmsnormRdquantTypeConfig
<
ck_tile
::
half_t
,
ck_tile
::
int8_t
>
{
using
ADataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
...
...
@@ -24,7 +24,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::half_t>
};
template
<
>
struct
AddRmsnormRdquantTypeConfig
<
ck_tile
::
bf16_t
>
struct
AddRmsnormRdquantTypeConfig
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
>
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
BDataType
=
ck_tile
::
bf16_t
;
...
...
@@ -35,13 +35,38 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t>
using
ComputeDataType
=
float
;
};
template
<
>
struct
AddRmsnormRdquantTypeConfig
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
>
{
using
ADataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
fp8_t
;
using
ComputeDataType
=
float
;
};
template
<
>
struct
AddRmsnormRdquantTypeConfig
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
>
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
BDataType
=
ck_tile
::
bf16_t
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
fp8_t
;
using
ComputeDataType
=
float
;
};
// Runtime arguments of the forward API. The struct adds no members of its own;
// everything is inherited from ck_tile::AddRmsnorm2dRdquantFwdHostArgs
// (struct inheritance is public by default).
struct add_rmsnorm2d_rdquant_fwd_args : ck_tile::AddRmsnorm2dRdquantFwdHostArgs
{
};
// This is used to pattern-match the internal kernel implementation, not to instantiate the kernel.
template
<
typename
DataType_
,
template
<
typename
InputDataType_
,
typename
QuantizedDataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
...
...
@@ -52,7 +77,8 @@ template <typename DataType_,
bool
kThreePass_
>
struct
add_rmsnorm2d_rdquant_fwd_traits_
{
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
using
InputDataType
=
ck_tile
::
remove_cvref_t
<
InputDataType_
>
;
using
QuantizedDataType
=
ck_tile
::
remove_cvref_t
<
QuantizedDataType_
>
;
static
constexpr
bool
is_warp_per_row
=
ThreadPerBlock_N_
<=
warpSize
;
static_assert
((
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
%
warpSize
==
0
);
...
...
@@ -114,7 +140,8 @@ float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_
// This is the public API, will be generated by script
struct
add_rmsnorm2d_rdquant_fwd_traits
{
std
::
string
data_type
;
std
::
string
input_data_type
;
std
::
string
quantized_data_type
;
bool
save_x
;
};
...
...
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
View file @
487826b3
...
...
@@ -4,7 +4,8 @@
#include <ck_tile/core.hpp>
#include "add_rmsnorm2d_rdquant_fwd.hpp"
template
<
typename
DataType_
,
template
<
typename
InputDataType_
,
typename
QuantizedDataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
...
...
@@ -13,7 +14,8 @@ template <typename DataType_,
bool
kPadN_
,
bool
kSaveX_
,
bool
kThreePass_
>
using
trait_
=
add_rmsnorm2d_rdquant_fwd_traits_
<
DataType_
,
using
trait_
=
add_rmsnorm2d_rdquant_fwd_traits_
<
InputDataType_
,
QuantizedDataType_
,
Repeat_M_
,
Repeat_N_
,
ThreadPerBlock_M_
,
...
...
@@ -23,8 +25,8 @@ using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<DataType_,
kSaveX_
,
kThreePass_
>
;
template
<
typename
data_type
>
float
add_rmsnorm2d_rdquant_fwd_b16_
(
add_rmsnorm2d_rdquant_fwd_traits
/*t*/
,
template
<
typename
input_data_type
,
typename
quantized_
data_type
>
float
add_rmsnorm2d_rdquant_fwd_b16_
(
add_rmsnorm2d_rdquant_fwd_traits
t
,
add_rmsnorm2d_rdquant_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
...
...
@@ -32,109 +34,133 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
// clang-format off
// rm rn tm tn vn pd x 3p
if
(
a
.
n
<=
64
)
{
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
128
)
{
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
1
,
4
,
64
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
2
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
256
)
{
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
1
,
4
,
64
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
2
,
4
,
64
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
512
)
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
1
,
4
,
64
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
2
,
4
,
64
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
4
,
64
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
4
,
64
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
8
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
8
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
768
)
{
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
3
,
4
,
64
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
6
,
4
,
64
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
6
,
4
,
64
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
12
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
12
,
4
,
64
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
1024
)
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
2
,
2
,
128
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
2
,
2
,
128
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
2
,
128
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
2
,
128
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
1536
)
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
3
,
4
,
64
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
3
,
2
,
128
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
3
,
2
,
128
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
3
,
1
,
256
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
6
,
1
,
256
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
2048
)
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
1
,
1
,
256
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
2
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
1
,
256
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
8
,
1
,
256
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
8
,
1
,
256
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
3072
)
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
128
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
3
,
1
,
128
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
3
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
6
,
1
,
256
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
1024
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
3
,
1
,
1024
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
4096
)
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
8192
)
{
if
(
a
.
n
<
8192
){
if
(
t
.
save_x
){
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>>
(
s
,
a
);
}
else
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
256
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
}
else
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
256
,
8
,
false
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
256
,
4
,
false
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
n
>
8192
)
{
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>>
(
s
,
a
);
else
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_
data_type
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>>
(
s
,
a
);
}
return
r
;
// clang-format on
...
...
@@ -144,16 +170,45 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
add_rmsnorm2d_rdquant_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
// Only support instance of save_x == true for now
assert
(
t
.
save_x
);
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
if
(
t
.
input_data_type
.
compare
(
"fp16"
)
==
0
&&
t
.
quantized_data_type
.
compare
(
"int8"
)
==
0
&&
t
.
save_x
)
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
input_data_type
.
compare
(
"fp16"
)
==
0
&&
t
.
quantized_data_type
.
compare
(
"int8"
)
==
0
&&
!
t
.
save_x
)
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
input_data_type
.
compare
(
"bf16"
)
==
0
&&
t
.
quantized_data_type
.
compare
(
"int8"
)
==
0
&&
t
.
save_x
)
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
input_data_type
.
compare
(
"bf16"
)
==
0
&&
t
.
quantized_data_type
.
compare
(
"int8"
)
==
0
&&
!
t
.
save_x
)
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
input_data_type
.
compare
(
"fp16"
)
==
0
&&
t
.
quantized_data_type
.
compare
(
"fp8"
)
==
0
&&
t
.
save_x
)
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
input_data_type
.
compare
(
"fp16"
)
==
0
&&
t
.
quantized_data_type
.
compare
(
"fp8"
)
==
0
&&
!
t
.
save_x
)
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
input_data_type
.
compare
(
"bf16"
)
==
0
&&
t
.
quantized_data_type
.
compare
(
"fp8"
)
==
0
&&
t
.
save_x
)
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
fp16
_t
>
(
t
,
a
,
s
);
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8
_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
data_type
.
compare
(
"bf16"
)
==
0
)
else
if
(
t
.
input_data_type
.
compare
(
"bf16"
)
==
0
&&
t
.
quantized_data_type
.
compare
(
"fp8"
)
==
0
&&
!
t
.
save_x
)
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
>
(
t
,
a
,
s
);
}
else
throw
std
::
runtime_error
(
"Without supported instances!"
);
...
...
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp
View file @
487826b3
...
...
@@ -15,8 +15,12 @@ template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
#endif
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
2
,
128
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
2
,
128
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
2
,
128
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp
View file @
487826b3
...
...
@@ -6,8 +6,12 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
6
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
6
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp
View file @
487826b3
...
...
@@ -6,9 +6,13 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp
View file @
487826b3
...
...
@@ -6,7 +6,10 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp
View file @
487826b3
...
...
@@ -6,9 +6,12 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
6
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
6
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp
View file @
487826b3
...
...
@@ -6,9 +6,12 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp
View file @
487826b3
...
...
@@ -6,8 +6,12 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp
View file @
487826b3
...
...
@@ -6,7 +6,10 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp
View file @
487826b3
...
...
@@ -6,7 +6,10 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
3
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
6
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
12
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
12
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp
View file @
487826b3
...
...
@@ -6,9 +6,29 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp
View file @
487826b3
...
...
@@ -6,9 +6,12 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp
View file @
487826b3
...
...
@@ -15,8 +15,12 @@ template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
#endif
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
2
,
128
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
2
,
128
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
2
,
128
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp
View file @
487826b3
...
...
@@ -6,8 +6,12 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
6
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
4
,
64
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
2
,
128
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp
View file @
487826b3
...
...
@@ -6,9 +6,13 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
1
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp
View file @
487826b3
...
...
@@ -6,7 +6,10 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
4
,
64
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
4
,
64
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp
View file @
487826b3
...
...
@@ -6,9 +6,12 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
6
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
128
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
1
,
256
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp
View file @
487826b3
...
...
@@ -6,9 +6,12 @@
// clang-format off
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment