Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ead5167a
Commit
ead5167a
authored
Nov 01, 2024
by
dummycoderfe
Browse files
merge develop
parents
da1a2829
03c6448b
Changes
137
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
355 additions
and
269 deletions
+355
-269
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp
+0
-14
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
...rm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
+0
-14
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp
...ernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp
+0
-13
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
...rm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
+0
-12
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp
...ernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp
+0
-12
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+251
-28
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
+23
-80
example/ck_tile/02_layernorm2d/misc/dquant.png
example/ck_tile/02_layernorm2d/misc/dquant.png
+0
-0
example/ck_tile/02_layernorm2d/misc/pnorm.png
example/ck_tile/02_layernorm2d/misc/pnorm.png
+0
-0
example/ck_tile/02_layernorm2d/script/perf_test.sh
example/ck_tile/02_layernorm2d/script/perf_test.sh
+35
-36
example/ck_tile/02_layernorm2d/script/smoke_test.sh
example/ck_tile/02_layernorm2d/script/smoke_test.sh
+30
-27
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
+1
-1
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
+1
-8
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
+1
-1
example/ck_tile/10_rmsnorm2d/script/perf_test.sh
example/ck_tile/10_rmsnorm2d/script/perf_test.sh
+2
-3
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
+1
-2
example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp
...le/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp
+3
-3
example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp
...d_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp
+4
-4
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
...orm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
+1
-8
example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh
example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh
+2
-3
No files found.
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp
deleted
100644 → 0
View file @
da1a2829
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
deleted
100644 → 0
View file @
da1a2829
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp
deleted
100644 → 0
View file @
da1a2829
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
deleted
100644 → 0
View file @
da1a2829
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp
deleted
100644 → 0
View file @
da1a2829
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
ead5167a
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include "layernorm2d_fwd.hpp"
#include <algorithm>
#include <cstring>
#include <cstring>
// different threshold for different dtype
// different threshold for different dtype
...
@@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[])
...
@@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[])
.
insert
(
"save_mv"
,
"0"
,
"save mean/variance(invstd) or not. set to 1 in training case"
)
.
insert
(
"save_mv"
,
"0"
,
"save mean/variance(invstd) or not. set to 1 in training case"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
.
insert
(
"prec_i"
,
"fp16"
,
"input precision"
)
.
insert
(
"prec_o"
,
"auto"
,
"output precision, set auto will be the same as input"
)
.
insert
(
"prec_sx"
,
"auto"
,
"output quant scale type, set auto will use fp32. used when fquant=1"
)
.
insert
(
"prec_sy"
,
"auto"
,
"output quant scale type, set auto will use fp32. used when fquant=1 or 2"
)
.
insert
(
"fadd"
,
"0"
,
"fused-add, 0:no fused add, 1:preadd+store, 2:preadd only"
)
.
insert
(
"fquant"
,
"0"
,
"fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
...
@@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[])
...
@@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[])
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
}
template
<
typename
DataType
,
bool
SaveMeanVar
>
template
<
typename
InDataType
,
typename
OutDataType
,
typename
XScaleDataType
,
typename
YScaleDataType
,
bool
SaveMeanVar
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
...
@@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
if
(
stride
<
0
)
stride
=
n
;
stride
=
n
;
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
std
::
string
prec_sx
=
arg_parser
.
get_str
(
"prec_sx"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
std
::
string
prec_sy
=
arg_parser
.
get_str
(
"prec_sy"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
if
(
prec_o
==
"auto"
)
{
prec_o
=
prec_i
;
}
if
(
prec_sx
==
"auto"
)
{
prec_sx
=
"fp32"
;
}
if
(
prec_sy
==
"auto"
)
{
prec_sy
=
"fp32"
;
}
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
fused_add
=
arg_parser
.
get_int
(
"fadd"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
if
(
fused_quant
==
1
&&
prec_o
!=
"int8"
)
{
std
::
cout
<<
"if fused_quant is 1, only support
\"
-prec_o=int8
\"
case"
<<
std
::
endl
;
return
false
;
}
assert
(
stride
>=
n
);
assert
(
stride
>=
n
);
using
TypeConfig
=
LayerNormTypeConfig
<
DataType
>
;
using
TypeConfig
=
LayerNormTypeConfig
<
InDataType
,
OutDataType
,
XScaleDataType
,
YScale
DataType
>
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
YDataType
=
typename
TypeConfig
::
YDataType
;
using
YDataType
=
typename
TypeConfig
::
YDataType
;
using
GammaDataType
=
typename
TypeConfig
::
GammaDataType
;
using
GammaDataType
=
typename
TypeConfig
::
GammaDataType
;
using
BetaDataType
=
typename
TypeConfig
::
BetaDataType
;
using
BetaDataType
=
typename
TypeConfig
::
BetaDataType
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
using
MeanDataType
=
using
MeanDataType
=
std
::
conditional_t
<
SaveMeanVar
,
typename
TypeConfig
::
MeanDataType
,
ck_tile
::
null_type
>
;
std
::
conditional_t
<
SaveMeanVar
,
typename
TypeConfig
::
MeanDataType
,
ck_tile
::
null_type
>
;
...
@@ -73,13 +112,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -73,13 +112,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
GammaDataType
>
gamma_host
({
n
});
ck_tile
::
HostTensor
<
GammaDataType
>
gamma_host
({
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
beta_host
({
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
beta_host
({
n
});
ck_tile
::
HostTensor
<
XResidualDataType
>
x_residual_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
m
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
m
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
m
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
m
});
ck_tile
::
HostTensor
<
YScaleDataType
>
y_scale_host_ref
({
m
});
ck_tile
::
HostTensor
<
YScaleDataType
>
y_scale_host_dev
({
m
});
ck_tile
::
HostTensor
<
XScaleDataType
>
x_scale_host
({
n
});
ck_tile
::
HostTensor
<
XScaleDataType
>
x_scale_host_dev
({
n
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XResidualDataType
>
{
-
.5
f
,
.5
f
}(
x_residual_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
-
1.
f
,
1.
f
}(
x_scale_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
BetaDataType
>
{
-
.5
f
,
.5
f
}(
beta_host
);
ck_tile
::
FillUniformDistribution
<
BetaDataType
>
{
-
.5
f
,
.5
f
}(
beta_host
);
...
@@ -87,22 +136,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -87,22 +136,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
beta_buf
(
beta_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
beta_buf
(
beta_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_scale_buf
(
y_scale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_scale_buf
(
x_scale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_residual_buf
(
x_residual_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_residual_buf
(
y_residual_host
.
get_element_space_size_in_bytes
());
x_buf
.
ToDevice
(
x_host
.
data
());
x_buf
.
ToDevice
(
x_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
beta_buf
.
ToDevice
(
beta_host
.
data
());
beta_buf
.
ToDevice
(
beta_host
.
data
());
x_residual_buf
.
ToDevice
(
x_residual_host
.
data
());
x_scale_buf
.
ToDevice
(
x_scale_host
.
data
());
auto
prec_str
=
[
&
]()
{
auto
base_str
=
prec_i
;
if
(
prec_i
!=
prec_o
)
{
base_str
+=
"|"
+
prec_o
;
}
if
(
fused_quant
==
1
)
{
base_str
+=
std
::
string
(
"("
)
+
prec_sy
+
")"
;
}
return
base_str
;
}();
std
::
cout
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"["
<<
prec_str
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
layernorm2d_fwd_traits
traits
{
data_type
,
SaveMeanVar
};
layernorm2d_fwd_traits
traits
{
prec_i
,
prec_o
,
prec_sx
,
prec_sy
,
SaveMeanVar
,
fused_add
,
fused_quant
};
layernorm2d_fwd_args
args
{
x_buf
.
GetDeviceBuffer
(),
layernorm2d_fwd_args
args
{
x_buf
.
GetDeviceBuffer
(),
fused_add
!=
0
?
x_residual_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
x_scale_buf
.
GetDeviceBuffer
()
:
nullptr
,
gamma_buf
.
GetDeviceBuffer
(),
gamma_buf
.
GetDeviceBuffer
(),
beta_buf
.
GetDeviceBuffer
(),
beta_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
nullptr
,
fused_add
==
1
?
y_residual_buf
.
GetDeviceBuffer
()
:
nullptr
,
nullptr
,
fused_quant
!=
0
?
y_scale_buf
.
GetDeviceBuffer
()
:
nullptr
,
nullptr
,
// p_mean, unsupported yet
nullptr
,
// p_invStd, unsupported yet
epsilon
,
epsilon
,
m
,
m
,
n
,
n
,
...
@@ -111,6 +187,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -111,6 +187,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
ave_time
=
layernorm2d_fwd
(
float
ave_time
=
layernorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
{
std
::
cout
<<
" not supported!"
<<
std
::
endl
<<
std
::
flush
;
return
false
;
}
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
sizeof
(
BetaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
sizeof
(
BetaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
...
@@ -122,6 +204,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -122,6 +204,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
// reference
// reference
if
(
fused_add
!=
0
)
{
// fused pre_add/pre_add_store
// TODO we accumulate directly to x_host for simplcity here...
std
::
transform
(
x_host
.
mData
.
cbegin
(),
x_host
.
mData
.
cend
(),
x_residual_host
.
mData
.
cbegin
(),
x_host
.
mData
.
begin
(),
[](
auto
x_
,
auto
r_
)
{
auto
o_
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_
)
+
ck_tile
::
type_convert
<
ComputeDataType
>
(
r_
);
return
ck_tile
::
type_convert
<
XDataType
>
(
o_
);
});
}
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
...
@@ -131,13 +228,83 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -131,13 +228,83 @@ bool run(const ck_tile::ArgParser& arg_parser)
InvStdDataType
>
(
InvStdDataType
>
(
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
);
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
);
if
(
fused_quant
!=
0
)
{
auto
dquant_functor
=
[
&
](
int
m_
,
auto
&
o_
,
auto
&
acc_
)
{
int
N_
=
acc_
.
mDesc
.
get_lengths
()[
1
];
if
(
fused_quant
==
1
)
{
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
// input smooth outlier
acc_
(
m_
,
n_
)
=
acc_
(
m_
,
n_
)
*
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_scale_host
(
n_
));
}
}
ComputeDataType
absmax
=
static_cast
<
ComputeDataType
>
(
0
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
const
auto
a
=
ck_tile
::
abs
(
acc_
(
m_
,
n_
));
absmax
=
a
>
absmax
?
a
:
absmax
;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType
y_scale
=
absmax
/
static_cast
<
ComputeDataType
>
(
127.0
);
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
o_
(
m_
,
n_
)
=
ck_tile
::
type_convert
<
YDataType
>
(
acc_
(
m_
,
n_
)
/
y_scale
);
}
};
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
MeanDataType
,
InvStdDataType
>
(
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
,
dquant_functor
);
}
else
{
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
MeanDataType
,
InvStdDataType
>
(
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
);
}
y_buf
.
FromDevice
(
y_host_dev
.
data
());
y_buf
.
FromDevice
(
y_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
();
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host_dev
({
m
,
n
},
{
stride
,
1
});
if
(
fused_add
==
1
)
{
y_residual_buf
.
FromDevice
(
y_residual_host_dev
.
data
());
}
auto
[
rtol
,
atol
]
=
get_elimit
<
InDataType
>
();
if
(
stride
==
n
)
if
(
stride
==
n
)
{
{
pass
=
ck_tile
::
check_err
(
pass
=
ck_tile
::
check_err
(
y_host_dev
,
y_host_ref
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
y_host_dev
,
y_host_ref
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
if
(
fused_add
==
1
)
{
pass
&=
ck_tile
::
check_err
(
y_residual_host_dev
,
x_host
,
std
::
string
(
"ADD Error: Incorrect results!"
),
rtol
,
atol
);
}
}
}
else
else
{
{
...
@@ -153,8 +320,31 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -153,8 +320,31 @@ bool run(const ck_tile::ArgParser& arg_parser)
std
::
string
(
"] Error: Incorrect results!"
),
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
rtol
,
atol
);
atol
);
if
(
fused_add
==
1
)
{
std
::
vector
<
YResidualDataType
>
y_residual_host_dev_row
(
y_residual_host_dev
.
begin
()
+
i_r
*
stride
,
y_residual_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
YResidualDataType
>
y_residual_host_ref_row
(
x_host
.
begin
()
+
i_r
*
stride
,
x_host
.
begin
()
+
i_r
*
stride
+
n
);
pass
&=
ck_tile
::
check_err
(
y_residual_host_dev_row
,
y_residual_host_ref_row
,
std
::
string
(
"ADD["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
atol
);
}
}
}
}
}
if
(
fused_quant
==
1
)
{
y_scale_buf
.
FromDevice
(
y_scale_host_dev
.
data
());
pass
&=
ck_tile
::
check_err
(
y_scale_host_dev
,
y_scale_host_ref
,
std
::
string
(
"SCALE Error: Incorrect results!"
),
rtol
,
atol
);
}
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
}
...
@@ -168,23 +358,56 @@ int main(int argc, char* argv[])
...
@@ -168,23 +358,56 @@ int main(int argc, char* argv[])
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
int
save_mv
=
arg_parser
.
get_int
(
"save_mv"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
if
(
data_type
==
"fp16"
&&
save_mv
)
std
::
string
prec_sx
=
arg_parser
.
get_str
(
"prec_sx"
);
std
::
string
prec_sy
=
arg_parser
.
get_str
(
"prec_sy"
);
if
(
prec_o
==
"auto"
)
{
prec_o
=
prec_i
;
}
if
(
prec_sx
==
"auto"
)
{
{
re
turn
run
<
ck_tile
::
half_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
p
re
c_sx
=
"fp32"
;
}
}
else
if
(
data_type
==
"fp16"
&&
!
save_mv
)
if
(
prec_sy
==
"auto"
)
{
{
re
turn
run
<
ck_tile
::
half_t
,
false
>
(
arg_parser
)
?
0
:
-
2
;
p
re
c_sy
=
"fp32"
;
}
}
else
if
(
data_type
==
"bf16"
&&
save_mv
)
int
save_mv
=
arg_parser
.
get_int
(
"save_mv"
);
// no dynamic quant case
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
save_mv
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
save_mv
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
// dynamic quant case, only in inference
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"int8"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
{
return
run
<
ck_tile
::
bf16_t
,
tru
e
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
half_t
,
ck_tile
::
int8_t
,
float
,
float
,
fals
e
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"bf16"
&&
!
save_mv
)
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"int8"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
{
return
run
<
ck_tile
::
bf16_t
,
tru
e
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
fals
e
>
(
arg_parser
)
?
0
:
-
2
;
}
}
return
-
3
;
return
-
3
;
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
View file @
ead5167a
...
@@ -8,31 +8,35 @@
...
@@ -8,31 +8,35 @@
#include "ck_tile/ops/layernorm2d.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>
#include <string>
template
<
typename
DataType
>
template
<
typename
InType
,
typename
OutType
,
typename
XScaleDataType_
,
typename
YScaleDataType_
>
struct
LayerNormTypeConfig
;
struct
LayerNormTypeConfig
;
template
<
>
template
<
typename
OutType
,
typename
XScaleDataType_
,
typename
YScaleDataType_
>
struct
LayerNormTypeConfig
<
ck_tile
::
half_t
>
struct
LayerNormTypeConfig
<
ck_tile
::
half_t
,
OutType
,
XScaleDataType_
,
YScaleDataType_
>
{
{
using
XDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
YDataType
=
OutType
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
using
MeanDataType
=
ck_tile
::
half_t
;
using
MeanDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
using
XScaleDataType
=
XScaleDataType_
;
using
YScaleDataType
=
YScaleDataType_
;
};
};
template
<
>
template
<
typename
OutType
,
typename
XScaleDataType_
,
typename
YScaleDataType_
>
struct
LayerNormTypeConfig
<
ck_tile
::
bf16_t
>
struct
LayerNormTypeConfig
<
ck_tile
::
bf16_t
,
OutType
,
XScaleDataType_
,
YScaleDataType_
>
{
{
using
XDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
YDataType
=
ck_tile
::
bf16_t
;
using
YDataType
=
OutType
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
BetaDataType
=
ck_tile
::
bf16_t
;
using
BetaDataType
=
ck_tile
::
bf16_t
;
using
MeanDataType
=
ck_tile
::
bf16_t
;
using
MeanDataType
=
ck_tile
::
bf16_t
;
using
InvStdDataType
=
ck_tile
::
bf16_t
;
using
InvStdDataType
=
ck_tile
::
bf16_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
using
XScaleDataType
=
XScaleDataType_
;
using
YScaleDataType
=
YScaleDataType_
;
};
};
// runtime args
// runtime args
...
@@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
...
@@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
{
{
};
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kTwoPass_
>
struct
layernorm2d_fwd_traits_
{
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
bool
is_warp_per_row
=
ThreadPerBlock_N_
<=
warpSize
;
static_assert
((
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
%
warpSize
==
0
);
static
constexpr
ck_tile
::
index_t
total_warps
=
(
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
/
warpSize
;
// num of warps along m
static
constexpr
ck_tile
::
index_t
BlockWarps_M
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
total_warps
*
(
warpSize
/
ThreadPerBlock_N_
);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return
total_warps
/
(
ThreadPerBlock_N_
/
warpSize
);
}
}();
// num of warps along n
static
constexpr
ck_tile
::
index_t
BlockWarps_N
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
1
;
}
else
{
static_assert
(
ThreadPerBlock_N_
%
warpSize
==
0
);
return
ThreadPerBlock_N_
/
warpSize
;
}
}();
static
constexpr
ck_tile
::
index_t
Repeat_M
=
Repeat_M_
;
static
constexpr
ck_tile
::
index_t
Repeat_N
=
Repeat_N_
;
static
constexpr
ck_tile
::
index_t
Block_M
=
Repeat_M_
*
ThreadPerBlock_M_
;
static
constexpr
ck_tile
::
index_t
Block_N
=
Repeat_N_
*
ThreadPerBlock_N_
*
Vector_N_
;
static
constexpr
ck_tile
::
index_t
Warp_M
=
ThreadPerBlock_M_
/
BlockWarps_M
;
static
constexpr
ck_tile
::
index_t
Warp_N
=
ThreadPerBlock_N_
/
BlockWarps_N
*
Vector_N_
;
using
BlockTile
=
ck_tile
::
sequence
<
Block_M
,
Block_N
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
BlockWarps_M
,
BlockWarps_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
Layernorm2dShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
template
<
typename
Traits_
>
float
layernorm2d_fwd_
(
const
ck_tile
::
stream_config
&
s
,
layernorm2d_fwd_args
a
);
// This is the public API, will be generated by script
// This is the public API, will be generated by script
struct
layernorm2d_fwd_traits
struct
layernorm2d_fwd_traits
{
{
std
::
string
data_type
;
std
::
string
prec_i
;
// input precision
bool
save_mean_var
;
std
::
string
prec_o
;
// output precision
// if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
// can set arbitrary(will skip check)
std
::
string
prec_sx
;
// x-scale, used for [1*N] input smooth quant
std
::
string
prec_sy
;
// y-scale, used for [M*1] output for next layer
bool
save_mean_var
;
//
int
fused_add
;
// 0:no-add, 1:pre-add-store, 2:pre-add
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
};
float
layernorm2d_fwd
(
layernorm2d_fwd_traits
,
layernorm2d_fwd_args
,
const
ck_tile
::
stream_config
&
);
float
layernorm2d_fwd
(
layernorm2d_fwd_traits
,
layernorm2d_fwd_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/02_layernorm2d/misc/dquant.png
0 → 100644
View file @
ead5167a
36 KB
example/ck_tile/02_layernorm2d/misc/pnorm.png
0 → 100644
View file @
ead5167a
31.4 KB
example/ck_tile/02_layernorm2d/script/perf_test.sh
View file @
ead5167a
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_layernorm2d_fwd
-type
f |
head
-n
1
)
"
# run from top of ck folder
$EXE
-m
=
1
-n
=
1
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
EXE
=
build/bin/tile_example_layernorm2d_fwd
$EXE
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
128
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
144
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
168
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
184
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
256
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
288
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
344
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
376
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
448
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
512
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
924
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1024
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1078
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1996
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
4080
-e
=
1e-12
-v
=
1
-prec_i
=
bf16
-repeat
=
1000
$EXE
-m
=
1
-n
=
1
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
128
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
128
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
144
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
144
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
168
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
168
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
184
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
184
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
256
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
256
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
288
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
288
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
344
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
344
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
376
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
376
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
448
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
448
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
512
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
512
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
924
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
924
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1024
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1024
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1078
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1078
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1996
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1996
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
4080
-e
=
1e-12
-v
=
1
-prec_i
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
4080
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
\ No newline at end of file
$EXE
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
128
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
144
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
168
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
184
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
256
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
288
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
344
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
376
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
448
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
512
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
924
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1024
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1078
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
1996
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
$EXE
-m
=
700
-n
=
4080
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
\ No newline at end of file
example/ck_tile/02_layernorm2d/script/smoke_test.sh
View file @
ead5167a
#!/bin/sh
#!/bin/sh
# call from top of CK folder
EXE
=
"
$(
find
.
-name
tile_example_layernorm2d_fwd
-type
f |
head
-n
1
)
"
EXE
=
./build/bin/tile_example_layernorm2d_fwd
for
fquant
in
""
"-fquant=1 -prec_o=int8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
$EXE
-prec
=
$pr_i
-m
=
99
-n
=
13
for
fadd
in
"0"
"1"
;
do
$EXE
-prec
=
$pr_i
-m
=
17
-n
=
16
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
100
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
17
-n
=
16
$EXE
-prec
=
$pr_i
-m
=
4
-n
=
128
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
100
$EXE
-prec
=
$pr_i
-m
=
80
-n
=
127
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
4
-n
=
128
$EXE
-prec
=
$pr_i
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
80
-n
=
127
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
599
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec
=
$pr_i
-m
=
19
-n
=
512
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
7
-n
=
599
$EXE
-prec
=
$pr_i
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
19
-n
=
512
$EXE
-prec
=
$pr_i
-m
=
11
-n
=
510
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec
=
$pr_i
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
11
-n
=
510
$EXE
-prec
=
$pr_i
-m
=
91
-n
=
636
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec
=
$pr_i
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
91
-n
=
636
$EXE
-prec
=
$pr_i
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec
=
$pr_i
-m
=
31
-n
=
1024
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec
=
$pr_i
-m
=
64
-n
=
1000
-stride
=
1004
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
31
-n
=
1024
$EXE
-prec
=
$pr_i
-m
=
8
-n
=
1501
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
64
-n
=
1000
-stride
=
1004
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
1826
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
8
-n
=
1501
$EXE
-prec
=
$pr_i
-m
=
5
-n
=
2040
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
1826
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
2734
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
5
-n
=
2040
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
7
-n
=
2734
$EXE
-prec
=
$pr_i
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
8192
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
10547
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
$EXE
-prec
=
$pr_i
-m
=
3
-n
=
17134
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
done
example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp
View file @
ead5167a
...
@@ -69,7 +69,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -69,7 +69,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Shape
=
ck_tile
::
Rmsnorm2d
Shape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Shape
=
ck_tile
::
Generic2dBlock
Shape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Problem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
XDataType
,
using
Problem
=
ck_tile
::
Rmsnorm2dFwdPipelineProblem
<
XDataType
,
GammaDataType
,
GammaDataType
,
ComputeDataType
,
ComputeDataType
,
...
...
example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp
View file @
ead5167a
...
@@ -28,7 +28,6 @@ float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
...
@@ -28,7 +28,6 @@ float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
rmsnorm2d_fwd_args
a
,
rmsnorm2d_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
const
ck_tile
::
stream_config
&
s
)
{
{
#if 1
float
r
=
-
1
;
float
r
=
-
1
;
// clang-format off
// clang-format off
// rm rn tm tn vn pd rms 2p
// rm rn tm tn vn pd rms 2p
...
@@ -128,16 +127,12 @@ float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
...
@@ -128,16 +127,12 @@ float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>>
(
s
,
a
);
r
=
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>>
(
s
,
a
);
}
}
return
r
;
return
r
;
#else
return
rmsnorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
#endif
// clang-format on
// clang-format on
}
}
float
rmsnorm2d_fwd
(
rmsnorm2d_fwd_traits
t
,
rmsnorm2d_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
float
rmsnorm2d_fwd
(
rmsnorm2d_fwd_traits
t
,
rmsnorm2d_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
{
float
r
=
-
1
;
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
{
return
rmsnorm2d_fwd_b16_
<
ck_tile
::
fp16_t
>
(
t
,
a
,
s
);
return
rmsnorm2d_fwd_b16_
<
ck_tile
::
fp16_t
>
(
t
,
a
,
s
);
...
@@ -146,8 +141,6 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile:
...
@@ -146,8 +141,6 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile:
{
{
return
rmsnorm2d_fwd_b16_
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
return
rmsnorm2d_fwd_b16_
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
}
}
if
(
r
<
0
)
else
throw
std
::
runtime_error
(
"Without supported instances!"
);
throw
std
::
runtime_error
(
"Without supported instances!"
);
return
r
;
}
}
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp
View file @
ead5167a
...
@@ -97,7 +97,7 @@ struct rmsnorm2d_fwd_traits_
...
@@ -97,7 +97,7 @@ struct rmsnorm2d_fwd_traits_
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
Rmsnorm2d
Shape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Shape
=
ck_tile
::
Generic2dBlock
Shape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
...
...
example/ck_tile/10_rmsnorm2d/script/perf_test.sh
View file @
ead5167a
#!/bin/sh
# run from top of ck folder
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
EXE
=
build/bin/tile_rmsnorm2d_fwd
$EXE
-m
=
1
-n
=
1
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
1
-n
=
1
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
...
...
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
View file @
ead5167a
#!/bin/sh
#!/bin/sh
# call from top of CK folder
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
EXE
=
./build/bin/tile_rmsnorm2d_fwd
for
pr_i
in
"fp16"
"bf16"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
$EXE
-prec
=
$pr_i
-m
=
99
-n
=
13
$EXE
-prec
=
$pr_i
-m
=
99
-n
=
13
...
...
example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp
View file @
ead5167a
...
@@ -18,7 +18,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::half_t>
...
@@ -18,7 +18,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::half_t>
using
BDataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
YScaleDataType
=
ck_tile
::
half_
t
;
using
YScaleDataType
=
floa
t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
};
};
...
@@ -30,7 +30,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t>
...
@@ -30,7 +30,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t>
using
BDataType
=
ck_tile
::
bf16_t
;
using
BDataType
=
ck_tile
::
bf16_t
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
YScaleDataType
=
ck_tile
::
bf16_
t
;
using
YScaleDataType
=
floa
t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
};
};
...
@@ -101,7 +101,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
...
@@ -101,7 +101,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
AddRmsnorm2dRdquant
Shape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Shape
=
ck_tile
::
Generic2dBlock
Shape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveX
=
kSaveX_
;
static
constexpr
bool
kSaveX
=
kSaveX_
;
...
...
example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp
View file @
ead5167a
...
@@ -66,7 +66,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -66,7 +66,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
BDataType
=
DataType
;
using
BDataType
=
DataType
;
using
GammaDataType
=
DataType
;
using
GammaDataType
=
DataType
;
using
XDataType
=
DataType
;
using
XDataType
=
DataType
;
using
YScaleDataType
=
DataType
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
...
@@ -99,12 +99,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -99,12 +99,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
constexpr
bool
kThreePass
=
true
;
constexpr
bool
kThreePass
=
true
;
using
BlockWarps
=
ck_tile
::
sequence
<
2
,
2
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
4
,
1
>
;
using
BlockTile
=
ck_tile
::
sequence
<
2
,
128
>
;
using
BlockTile
=
ck_tile
::
sequence
<
4
,
128
>
;
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
WarpTile
=
ck_tile
::
sequence
<
1
,
64
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Shape
=
ck_tile
::
AddRmsnorm2dRdquant
Shape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Shape
=
ck_tile
::
Generic2dBlock
Shape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Problem
=
ck_tile
::
AddRmsnorm2dRdquantFwdPipelineProblem
<
ADataType
,
using
Problem
=
ck_tile
::
AddRmsnorm2dRdquantFwdPipelineProblem
<
ADataType
,
BDataType
,
BDataType
,
GammaDataType
,
GammaDataType
,
...
...
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp
View file @
ead5167a
...
@@ -28,7 +28,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
...
@@ -28,7 +28,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
add_rmsnorm2d_rdquant_fwd_args
a
,
add_rmsnorm2d_rdquant_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
const
ck_tile
::
stream_config
&
s
)
{
{
#if 1
float
r
=
-
1
;
float
r
=
-
1
;
// clang-format off
// clang-format off
// rm rn tm tn vn pd x 3p
// rm rn tm tn vn pd x 3p
...
@@ -128,9 +127,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
...
@@ -128,9 +127,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
true
>>
(
s
,
a
);
r
=
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
,
true
>>
(
s
,
a
);
}
}
return
r
;
return
r
;
#else
return
add_rmsnorm2d_rdquant_fwd_
<
trait_
<
data_type
,
1
,
1
,
2
,
128
,
8
,
true
,
true
,
false
>>
(
s
,
a
);
#endif
// clang-format on
// clang-format on
}
}
...
@@ -139,7 +135,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
...
@@ -139,7 +135,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
const
ck_tile
::
stream_config
&
s
)
const
ck_tile
::
stream_config
&
s
)
{
{
float
r
=
-
1
;
// Only support instance of save_x == true for now
// Only support instance of save_x == true for now
assert
(
t
.
save_x
);
assert
(
t
.
save_x
);
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
...
@@ -150,8 +145,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
...
@@ -150,8 +145,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
{
{
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
return
add_rmsnorm2d_rdquant_fwd_b16_
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
}
}
if
(
r
<
0
)
else
throw
std
::
runtime_error
(
"Without supported instances!"
);
throw
std
::
runtime_error
(
"Without supported instances!"
);
return
r
;
}
}
example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh
View file @
ead5167a
#!/bin/sh
# run from top of ck folder
EXE
=
"
$(
find
.
-name
tile_add_rmsnorm2d_rdquant_fwd
-type
f |
head
-n
1
)
"
EXE
=
build/bin/tile_add_rmsnorm2d_rdquant_fwd
$EXE
-m
=
1
-n
=
1
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
1
-n
=
1
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
$EXE
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec
=
bf16
-repeat
=
1000
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment