Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1f131e76
Commit
1f131e76
authored
Jan 10, 2025
by
aska-0096
Browse files
fp8 sanity
parent
487826b3
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
44 additions
and
50 deletions
+44
-50
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
...orm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
+8
-8
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp
+12
-12
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp
...nces/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp
+4
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp
...stances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp
+12
-13
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp
...nces/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp
+4
-4
include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp
..._tile/host/reference/reference_rowwise_quantization2d.hpp
+1
-6
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
.../pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
+2
-2
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
...ipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
+1
-1
No files found.
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
View file @
1f131e76
...
@@ -122,9 +122,9 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
...
@@ -122,9 +122,9 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
if
(
a
.
n
<
8192
){
if
(
a
.
n
<
8192
){
if
(
t
.
save_x
){
if
(
t
.
save_x
){
if
(
a
.
n
%
8
==
0
)
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>>
(
s
,
a
);
else
else
...
@@ -132,9 +132,9 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
...
@@ -132,9 +132,9 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
}
}
else
{
else
{
if
(
a
.
n
%
8
==
0
)
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
256
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
2
,
1
,
512
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
512
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
else
...
@@ -143,9 +143,9 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
...
@@ -143,9 +143,9 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
}
}
else
{
else
{
if
(
a
.
n
%
8
==
0
)
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
256
,
8
,
false
,
false
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
2
,
1
,
512
,
8
,
false
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
256
,
4
,
false
,
false
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
512
,
4
,
false
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>>
(
s
,
a
);
else
else
...
@@ -154,9 +154,9 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
...
@@ -154,9 +154,9 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
}
}
else
if
(
a
.
n
>
8192
)
{
else
if
(
a
.
n
>
8192
)
{
if
(
a
.
n
%
8
==
0
)
if
(
a
.
n
%
8
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
else
if
(
a
.
n
%
4
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
true
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
else
if
(
a
.
n
%
2
==
0
)
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
input_data_type
,
quantized_data_type
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>>
(
s
,
a
);
else
else
...
...
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp
View file @
1f131e76
...
@@ -6,28 +6,28 @@
...
@@ -6,28 +6,28 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd x 3p
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
512
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
512
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
512
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
512
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
...
...
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp
View file @
1f131e76
...
@@ -6,12 +6,12 @@
...
@@ -6,12 +6,12 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd x 3p
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp
View file @
1f131e76
...
@@ -6,29 +6,28 @@
...
@@ -6,29 +6,28 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd x 3p
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
512
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
512
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
512
,
8
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
512
,
4
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
false
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp
View file @
1f131e76
...
@@ -6,12 +6,12 @@
...
@@ -6,12 +6,12 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd x 3p
// rm rn tm tn vn pd x 3p
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
512
,
8
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
256
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
512
,
4
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
2
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
1
,
1024
,
1
,
true
,
true
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp
View file @
1f131e76
...
@@ -22,12 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>&
...
@@ -22,12 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>&
// scale = amax / 127 for int8
// scale = amax / 127 for int8
auto
v_scale
=
type_convert
<
XDataType
>
(
scale_m
(
m
));
auto
v_scale
=
type_convert
<
XDataType
>
(
scale_m
(
m
));
auto
v_qx
=
v_x
/
v_scale
;
auto
v_qx
=
v_x
/
v_scale
;
qx_m_n
(
m
,
n
)
=
saturates
<
QXDataType
>
{}(
v_qx
);
qx_m_n
(
m
,
n
)
=
type_convert
<
QXDataType
>
(
saturates
<
QXDataType
>
{}(
v_qx
));
if
(
m
==
0
&&
n
==
4
)
printf
(
"Qy: %lf, Satruates Qy: %lf
\n
"
,
type_convert
<
float
>
(
v_qx
),
type_convert
<
float
>
(
qx_m_n
(
m
,
n
)));
}
}
};
};
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp
View file @
1f131e76
...
@@ -89,7 +89,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
...
@@ -89,7 +89,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
auto
x
=
tile_elementwise_in
(
auto
x
=
tile_elementwise_in
(
[
&
](
const
auto
&
a_
,
const
auto
&
b_
)
{
[
&
](
const
auto
&
a_
,
const
auto
&
b_
)
{
return
type_convert
<
ComputeDataType
>
(
a_
)
+
type_convert
<
ComputeDataType
>
(
b_
);
return
type_convert
<
ComputeDataType
>
(
a_
+
b_
);
},
},
a
,
a
,
b
);
b
);
...
@@ -157,7 +157,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
...
@@ -157,7 +157,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
sweep_tile
(
qy
,
[
&
,
yscale_
=
yscale
](
auto
idx
)
{
sweep_tile
(
qy
,
[
&
,
yscale_
=
yscale
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
auto
qy_
=
y
[
idx
]
/
yscale_
[
i_idx
];
auto
qy_
=
y
[
idx
]
/
yscale_
[
i_idx
];
qy
(
idx
)
=
saturates
<
QYDataType
>
{}(
qy_
);
qy
(
idx
)
=
type_convert
<
QYDataType
>
(
saturates
<
QYDataType
>
{}(
qy_
)
)
;
});
});
store_tile
(
qy_window
,
qy
);
store_tile
(
qy_window
,
qy
);
}
}
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp
View file @
1f131e76
...
@@ -260,7 +260,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
...
@@ -260,7 +260,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
x_
*
inv_rms
[
i_idx
]
*
gamma_
;
auto
y_
=
x_
*
inv_rms
[
i_idx
]
*
gamma_
;
auto
qy_
=
y_
/
yscale
[
i_idx
];
auto
qy_
=
y_
/
yscale
[
i_idx
];
qy
(
idx
)
=
saturates
<
QYDataType
>
{}(
qy_
);
qy
(
idx
)
=
type_convert
<
QYDataType
>
(
saturates
<
QYDataType
>
{}(
qy_
)
)
;
});
});
store_tile
(
qy_window
,
qy
);
store_tile
(
qy_window
,
qy
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment