Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f20e48f1
Commit
f20e48f1
authored
Nov 05, 2024
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into update_cka8w8
parents
b97c6876
0c9012fb
Changes
361
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
974 additions
and
672 deletions
+974
-672
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp
+0
-14
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp
+0
-14
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
...rm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
+0
-14
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp
...ernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp
+0
-13
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
...rm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
+0
-12
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp
...ernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp
+0
-12
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+251
-28
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
+23
-80
example/ck_tile/02_layernorm2d/misc/dquant.png
example/ck_tile/02_layernorm2d/misc/dquant.png
+0
-0
example/ck_tile/02_layernorm2d/misc/pnorm.png
example/ck_tile/02_layernorm2d/misc/pnorm.png
+0
-0
example/ck_tile/02_layernorm2d/script/perf_test.sh
example/ck_tile/02_layernorm2d/script/perf_test.sh
+35
-36
example/ck_tile/02_layernorm2d/script/smoke_test.sh
example/ck_tile/02_layernorm2d/script/smoke_test.sh
+30
-27
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+2
-2
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+45
-321
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+26
-6
example/ck_tile/03_gemm/gemm_mem_pipeline.cpp
example/ck_tile/03_gemm/gemm_mem_pipeline.cpp
+188
-0
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+217
-0
example/ck_tile/05_reduce/reduce.cpp
example/ck_tile/05_reduce/reduce.cpp
+35
-30
example/ck_tile/05_reduce/reduce.hpp
example/ck_tile/05_reduce/reduce.hpp
+109
-63
example/ck_tile/06_permute/CMakeLists.txt
example/ck_tile/06_permute/CMakeLists.txt
+13
-0
No files found.
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp
deleted
100644 → 0
View file @
b97c6876
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp
deleted
100644 → 0
View file @
b97c6876
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp
deleted
100644 → 0
View file @
b97c6876
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp
deleted
100644 → 0
View file @
b97c6876
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
deleted
100644 → 0
View file @
b97c6876
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp
deleted
100644 → 0
View file @
b97c6876
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
f20e48f1
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include "layernorm2d_fwd.hpp"
#include <algorithm>
#include <cstring>
#include <cstring>
// different threshold for different dtype
// different threshold for different dtype
...
@@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[])
...
@@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[])
.
insert
(
"save_mv"
,
"0"
,
"save mean/variance(invstd) or not. set to 1 in training case"
)
.
insert
(
"save_mv"
,
"0"
,
"save mean/variance(invstd) or not. set to 1 in training case"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
.
insert
(
"prec_i"
,
"fp16"
,
"input precision"
)
.
insert
(
"prec_o"
,
"auto"
,
"output precision, set auto will be the same as input"
)
.
insert
(
"prec_sx"
,
"auto"
,
"output quant scale type, set auto will use fp32. used when fquant=1"
)
.
insert
(
"prec_sy"
,
"auto"
,
"output quant scale type, set auto will use fp32. used when fquant=1 or 2"
)
.
insert
(
"fadd"
,
"0"
,
"fused-add, 0:no fused add, 1:preadd+store, 2:preadd only"
)
.
insert
(
"fquant"
,
"0"
,
"fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
...
@@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[])
...
@@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[])
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
}
template
<
typename
DataType
,
bool
SaveMeanVar
>
template
<
typename
InDataType
,
typename
OutDataType
,
typename
XScaleDataType
,
typename
YScaleDataType
,
bool
SaveMeanVar
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
...
@@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
if
(
stride
<
0
)
stride
=
n
;
stride
=
n
;
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
std
::
string
prec_sx
=
arg_parser
.
get_str
(
"prec_sx"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
std
::
string
prec_sy
=
arg_parser
.
get_str
(
"prec_sy"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
if
(
prec_o
==
"auto"
)
{
prec_o
=
prec_i
;
}
if
(
prec_sx
==
"auto"
)
{
prec_sx
=
"fp32"
;
}
if
(
prec_sy
==
"auto"
)
{
prec_sy
=
"fp32"
;
}
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
fused_add
=
arg_parser
.
get_int
(
"fadd"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
if
(
fused_quant
==
1
&&
prec_o
!=
"int8"
)
{
std
::
cout
<<
"if fused_quant is 1, only support
\"
-prec_o=int8
\"
case"
<<
std
::
endl
;
return
false
;
}
assert
(
stride
>=
n
);
assert
(
stride
>=
n
);
using
TypeConfig
=
LayerNormTypeConfig
<
DataType
>
;
using
TypeConfig
=
LayerNormTypeConfig
<
InDataType
,
OutDataType
,
XScaleDataType
,
YScale
DataType
>
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
YDataType
=
typename
TypeConfig
::
YDataType
;
using
YDataType
=
typename
TypeConfig
::
YDataType
;
using
GammaDataType
=
typename
TypeConfig
::
GammaDataType
;
using
GammaDataType
=
typename
TypeConfig
::
GammaDataType
;
using
BetaDataType
=
typename
TypeConfig
::
BetaDataType
;
using
BetaDataType
=
typename
TypeConfig
::
BetaDataType
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
using
MeanDataType
=
using
MeanDataType
=
std
::
conditional_t
<
SaveMeanVar
,
typename
TypeConfig
::
MeanDataType
,
ck_tile
::
null_type
>
;
std
::
conditional_t
<
SaveMeanVar
,
typename
TypeConfig
::
MeanDataType
,
ck_tile
::
null_type
>
;
...
@@ -73,13 +112,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -73,13 +112,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
GammaDataType
>
gamma_host
({
n
});
ck_tile
::
HostTensor
<
GammaDataType
>
gamma_host
({
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
beta_host
({
n
});
ck_tile
::
HostTensor
<
BetaDataType
>
beta_host
({
n
});
ck_tile
::
HostTensor
<
XResidualDataType
>
x_residual_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
m
});
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
m
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
m
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
m
});
ck_tile
::
HostTensor
<
YScaleDataType
>
y_scale_host_ref
({
m
});
ck_tile
::
HostTensor
<
YScaleDataType
>
y_scale_host_dev
({
m
});
ck_tile
::
HostTensor
<
XScaleDataType
>
x_scale_host
({
n
});
ck_tile
::
HostTensor
<
XScaleDataType
>
x_scale_host_dev
({
n
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XResidualDataType
>
{
-
.5
f
,
.5
f
}(
x_residual_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
-
1.
f
,
1.
f
}(
x_scale_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
BetaDataType
>
{
-
.5
f
,
.5
f
}(
beta_host
);
ck_tile
::
FillUniformDistribution
<
BetaDataType
>
{
-
.5
f
,
.5
f
}(
beta_host
);
...
@@ -87,22 +136,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -87,22 +136,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
beta_buf
(
beta_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
beta_buf
(
beta_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_scale_buf
(
y_scale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_scale_buf
(
x_scale_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x_residual_buf
(
x_residual_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_residual_buf
(
y_residual_host
.
get_element_space_size_in_bytes
());
x_buf
.
ToDevice
(
x_host
.
data
());
x_buf
.
ToDevice
(
x_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
beta_buf
.
ToDevice
(
beta_host
.
data
());
beta_buf
.
ToDevice
(
beta_host
.
data
());
x_residual_buf
.
ToDevice
(
x_residual_host
.
data
());
x_scale_buf
.
ToDevice
(
x_scale_host
.
data
());
auto
prec_str
=
[
&
]()
{
auto
base_str
=
prec_i
;
if
(
prec_i
!=
prec_o
)
{
base_str
+=
"|"
+
prec_o
;
}
if
(
fused_quant
==
1
)
{
base_str
+=
std
::
string
(
"("
)
+
prec_sy
+
")"
;
}
return
base_str
;
}();
std
::
cout
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"["
<<
prec_str
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
layernorm2d_fwd_traits
traits
{
data_type
,
SaveMeanVar
};
layernorm2d_fwd_traits
traits
{
prec_i
,
prec_o
,
prec_sx
,
prec_sy
,
SaveMeanVar
,
fused_add
,
fused_quant
};
layernorm2d_fwd_args
args
{
x_buf
.
GetDeviceBuffer
(),
layernorm2d_fwd_args
args
{
x_buf
.
GetDeviceBuffer
(),
fused_add
!=
0
?
x_residual_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
x_scale_buf
.
GetDeviceBuffer
()
:
nullptr
,
gamma_buf
.
GetDeviceBuffer
(),
gamma_buf
.
GetDeviceBuffer
(),
beta_buf
.
GetDeviceBuffer
(),
beta_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
nullptr
,
fused_add
==
1
?
y_residual_buf
.
GetDeviceBuffer
()
:
nullptr
,
nullptr
,
fused_quant
!=
0
?
y_scale_buf
.
GetDeviceBuffer
()
:
nullptr
,
nullptr
,
// p_mean, unsupported yet
nullptr
,
// p_invStd, unsupported yet
epsilon
,
epsilon
,
m
,
m
,
n
,
n
,
...
@@ -111,6 +187,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -111,6 +187,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float
ave_time
=
layernorm2d_fwd
(
float
ave_time
=
layernorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
{
std
::
cout
<<
" not supported!"
<<
std
::
endl
<<
std
::
flush
;
return
false
;
}
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
m
*
n
+
sizeof
(
GammaDataType
)
*
n
+
sizeof
(
BetaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
sizeof
(
BetaDataType
)
*
n
+
sizeof
(
YDataType
)
*
m
*
n
;
...
@@ -122,6 +204,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -122,6 +204,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
// reference
// reference
if
(
fused_add
!=
0
)
{
// fused pre_add/pre_add_store
// TODO we accumulate directly to x_host for simplcity here...
std
::
transform
(
x_host
.
mData
.
cbegin
(),
x_host
.
mData
.
cend
(),
x_residual_host
.
mData
.
cbegin
(),
x_host
.
mData
.
begin
(),
[](
auto
x_
,
auto
r_
)
{
auto
o_
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_
)
+
ck_tile
::
type_convert
<
ComputeDataType
>
(
r_
);
return
ck_tile
::
type_convert
<
XDataType
>
(
o_
);
});
}
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
...
@@ -131,13 +228,83 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -131,13 +228,83 @@ bool run(const ck_tile::ArgParser& arg_parser)
InvStdDataType
>
(
InvStdDataType
>
(
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
);
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
);
if
(
fused_quant
!=
0
)
{
auto
dquant_functor
=
[
&
](
int
m_
,
auto
&
o_
,
auto
&
acc_
)
{
int
N_
=
acc_
.
mDesc
.
get_lengths
()[
1
];
if
(
fused_quant
==
1
)
{
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
// input smooth outlier
acc_
(
m_
,
n_
)
=
acc_
(
m_
,
n_
)
*
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_scale_host
(
n_
));
}
}
ComputeDataType
absmax
=
static_cast
<
ComputeDataType
>
(
0
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
const
auto
a
=
ck_tile
::
abs
(
acc_
(
m_
,
n_
));
absmax
=
a
>
absmax
?
a
:
absmax
;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType
y_scale
=
absmax
/
static_cast
<
ComputeDataType
>
(
127.0
);
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
o_
(
m_
,
n_
)
=
ck_tile
::
type_convert
<
YDataType
>
(
acc_
(
m_
,
n_
)
/
y_scale
);
}
};
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
MeanDataType
,
InvStdDataType
>
(
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
,
dquant_functor
);
}
else
{
ck_tile
::
reference_layernorm2d_fwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
MeanDataType
,
InvStdDataType
>
(
x_host
,
gamma_host
,
beta_host
,
y_host_ref
,
mean_host_ref
,
invStd_host_ref
,
epsilon
);
}
y_buf
.
FromDevice
(
y_host_dev
.
data
());
y_buf
.
FromDevice
(
y_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
();
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host_dev
({
m
,
n
},
{
stride
,
1
});
if
(
fused_add
==
1
)
{
y_residual_buf
.
FromDevice
(
y_residual_host_dev
.
data
());
}
auto
[
rtol
,
atol
]
=
get_elimit
<
InDataType
>
();
if
(
stride
==
n
)
if
(
stride
==
n
)
{
{
pass
=
ck_tile
::
check_err
(
pass
=
ck_tile
::
check_err
(
y_host_dev
,
y_host_ref
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
y_host_dev
,
y_host_ref
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
if
(
fused_add
==
1
)
{
pass
&=
ck_tile
::
check_err
(
y_residual_host_dev
,
x_host
,
std
::
string
(
"ADD Error: Incorrect results!"
),
rtol
,
atol
);
}
}
}
else
else
{
{
...
@@ -153,8 +320,31 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -153,8 +320,31 @@ bool run(const ck_tile::ArgParser& arg_parser)
std
::
string
(
"] Error: Incorrect results!"
),
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
rtol
,
atol
);
atol
);
if
(
fused_add
==
1
)
{
std
::
vector
<
YResidualDataType
>
y_residual_host_dev_row
(
y_residual_host_dev
.
begin
()
+
i_r
*
stride
,
y_residual_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
YResidualDataType
>
y_residual_host_ref_row
(
x_host
.
begin
()
+
i_r
*
stride
,
x_host
.
begin
()
+
i_r
*
stride
+
n
);
pass
&=
ck_tile
::
check_err
(
y_residual_host_dev_row
,
y_residual_host_ref_row
,
std
::
string
(
"ADD["
)
+
std
::
to_string
(
i_r
)
+
std
::
string
(
"] Error: Incorrect results!"
),
rtol
,
atol
);
}
}
}
}
}
if
(
fused_quant
==
1
)
{
y_scale_buf
.
FromDevice
(
y_scale_host_dev
.
data
());
pass
&=
ck_tile
::
check_err
(
y_scale_host_dev
,
y_scale_host_ref
,
std
::
string
(
"SCALE Error: Incorrect results!"
),
rtol
,
atol
);
}
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
}
...
@@ -168,23 +358,56 @@ int main(int argc, char* argv[])
...
@@ -168,23 +358,56 @@ int main(int argc, char* argv[])
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
int
save_mv
=
arg_parser
.
get_int
(
"save_mv"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
if
(
data_type
==
"fp16"
&&
save_mv
)
std
::
string
prec_sx
=
arg_parser
.
get_str
(
"prec_sx"
);
std
::
string
prec_sy
=
arg_parser
.
get_str
(
"prec_sy"
);
if
(
prec_o
==
"auto"
)
{
prec_o
=
prec_i
;
}
if
(
prec_sx
==
"auto"
)
{
{
re
turn
run
<
ck_tile
::
half_t
,
true
>
(
arg_parser
)
?
0
:
-
2
;
p
re
c_sx
=
"fp32"
;
}
}
else
if
(
data_type
==
"fp16"
&&
!
save_mv
)
if
(
prec_sy
==
"auto"
)
{
{
re
turn
run
<
ck_tile
::
half_t
,
false
>
(
arg_parser
)
?
0
:
-
2
;
p
re
c_sy
=
"fp32"
;
}
}
else
if
(
data_type
==
"bf16"
&&
save_mv
)
int
save_mv
=
arg_parser
.
get_int
(
"save_mv"
);
// no dynamic quant case
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
save_mv
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
save_mv
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
// dynamic quant case, only in inference
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"int8"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
{
return
run
<
ck_tile
::
bf16_t
,
tru
e
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
half_t
,
ck_tile
::
int8_t
,
float
,
float
,
fals
e
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"bf16"
&&
!
save_mv
)
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"int8"
&&
prec_sx
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_mv
)
{
{
return
run
<
ck_tile
::
bf16_t
,
tru
e
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
fals
e
>
(
arg_parser
)
?
0
:
-
2
;
}
}
return
-
3
;
return
-
3
;
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
View file @
f20e48f1
...
@@ -8,31 +8,35 @@
...
@@ -8,31 +8,35 @@
#include "ck_tile/ops/layernorm2d.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>
#include <string>
template
<
typename
DataType
>
template
<
typename
InType
,
typename
OutType
,
typename
XScaleDataType_
,
typename
YScaleDataType_
>
struct
LayerNormTypeConfig
;
struct
LayerNormTypeConfig
;
template
<
>
template
<
typename
OutType
,
typename
XScaleDataType_
,
typename
YScaleDataType_
>
struct
LayerNormTypeConfig
<
ck_tile
::
half_t
>
struct
LayerNormTypeConfig
<
ck_tile
::
half_t
,
OutType
,
XScaleDataType_
,
YScaleDataType_
>
{
{
using
XDataType
=
ck_tile
::
half_t
;
using
XDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
YDataType
=
OutType
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
using
MeanDataType
=
ck_tile
::
half_t
;
using
MeanDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
using
XScaleDataType
=
XScaleDataType_
;
using
YScaleDataType
=
YScaleDataType_
;
};
};
template
<
>
template
<
typename
OutType
,
typename
XScaleDataType_
,
typename
YScaleDataType_
>
struct
LayerNormTypeConfig
<
ck_tile
::
bf16_t
>
struct
LayerNormTypeConfig
<
ck_tile
::
bf16_t
,
OutType
,
XScaleDataType_
,
YScaleDataType_
>
{
{
using
XDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
ck_tile
::
bf16_t
;
using
YDataType
=
ck_tile
::
bf16_t
;
using
YDataType
=
OutType
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
GammaDataType
=
ck_tile
::
bf16_t
;
using
BetaDataType
=
ck_tile
::
bf16_t
;
using
BetaDataType
=
ck_tile
::
bf16_t
;
using
MeanDataType
=
ck_tile
::
bf16_t
;
using
MeanDataType
=
ck_tile
::
bf16_t
;
using
InvStdDataType
=
ck_tile
::
bf16_t
;
using
InvStdDataType
=
ck_tile
::
bf16_t
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
using
XScaleDataType
=
XScaleDataType_
;
using
YScaleDataType
=
YScaleDataType_
;
};
};
// runtime args
// runtime args
...
@@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
...
@@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
{
{
};
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kTwoPass_
>
struct
layernorm2d_fwd_traits_
{
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
bool
is_warp_per_row
=
ThreadPerBlock_N_
<=
warpSize
;
static_assert
((
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
%
warpSize
==
0
);
static
constexpr
ck_tile
::
index_t
total_warps
=
(
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
/
warpSize
;
// num of warps along m
static
constexpr
ck_tile
::
index_t
BlockWarps_M
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
total_warps
*
(
warpSize
/
ThreadPerBlock_N_
);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return
total_warps
/
(
ThreadPerBlock_N_
/
warpSize
);
}
}();
// num of warps along n
static
constexpr
ck_tile
::
index_t
BlockWarps_N
=
[]()
{
if
constexpr
(
is_warp_per_row
)
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
1
;
}
else
{
static_assert
(
ThreadPerBlock_N_
%
warpSize
==
0
);
return
ThreadPerBlock_N_
/
warpSize
;
}
}();
static
constexpr
ck_tile
::
index_t
Repeat_M
=
Repeat_M_
;
static
constexpr
ck_tile
::
index_t
Repeat_N
=
Repeat_N_
;
static
constexpr
ck_tile
::
index_t
Block_M
=
Repeat_M_
*
ThreadPerBlock_M_
;
static
constexpr
ck_tile
::
index_t
Block_N
=
Repeat_N_
*
ThreadPerBlock_N_
*
Vector_N_
;
static
constexpr
ck_tile
::
index_t
Warp_M
=
ThreadPerBlock_M_
/
BlockWarps_M
;
static
constexpr
ck_tile
::
index_t
Warp_N
=
ThreadPerBlock_N_
/
BlockWarps_N
*
Vector_N_
;
using
BlockTile
=
ck_tile
::
sequence
<
Block_M
,
Block_N
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
BlockWarps_M
,
BlockWarps_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
Layernorm2dShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
template
<
typename
Traits_
>
float
layernorm2d_fwd_
(
const
ck_tile
::
stream_config
&
s
,
layernorm2d_fwd_args
a
);
// This is the public API, will be generated by script.
//
// NOTE(review): the diff left two conflicting definitions of this struct (the
// old one with a single `data_type` member and the new `prec_*` version); a
// duplicate definition is a compile error, so only the merged/new definition
// is kept here.
struct layernorm2d_fwd_traits
{
    std::string prec_i; // input precision
    std::string prec_o; // output precision

    // if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set
    // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
    // can set arbitrary(will skip check)
    std::string prec_sx; // x-scale, used for [1*N] input smooth quant
    std::string prec_sy; // y-scale, used for [M*1] output for next layer

    bool save_mean_var; // also emit per-row mean/variance outputs

    int fused_add;   // 0:no-add, 1:pre-add-store, 2:pre-add
    int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
// Public entry point: dispatches on the runtime traits to the matching
// pre-generated kernel instance.  Returns the value reported by the selected
// instance (presumably the kernel time -- confirm in the dispatcher).
//
// NOTE(review): the diff duplicated this declaration verbatim; the redundant
// copy is removed.
float layernorm2d_fwd(layernorm2d_fwd_traits,
                      layernorm2d_fwd_args,
                      const ck_tile::stream_config&);
example/ck_tile/02_layernorm2d/misc/dquant.png
0 → 100644
View file @
f20e48f1
36 KB
example/ck_tile/02_layernorm2d/misc/pnorm.png
0 → 100644
View file @
f20e48f1
31.4 KB
example/ck_tile/02_layernorm2d/script/perf_test.sh
View file @
f20e48f1
#!/bin/sh
# run from top of ck folder
#
# NOTE(review): the flattened diff interleaved the old script (hard-coded
# EXE=build/bin/..., old -prec= flag) with the new one; only the new-side
# commands are kept, deduplicated into loops over precision and N.
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"

# smallest possible problem as a sanity check
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000

# sweep a representative set of row lengths for both input precisions
for prec in bf16 fp16 ; do
for n in 80 128 144 168 184 256 288 344 376 448 512 924 1024 1078 1996 4080 ; do
    $EXE -m=700 -n=$n -e=1e-12 -v=1 -prec_i=$prec -repeat=1000
done
done
example/ck_tile/02_layernorm2d/script/smoke_test.sh
View file @
f20e48f1
#!/bin/sh
# call from top of CK folder
#
# NOTE(review): the flattened diff interleaved the old single-loop -prec=
# script with the new triple-loop -prec_i= script; only the new-side version
# is kept.
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"

# fquant: "" = no fused quantization, 1 = smooth dynamic quant to int8 output
for fquant in "" "-fquant=1 -prec_o=int8" ; do
# input precision
for pr_i in "fp16" "bf16" ; do
# fadd: 0 = no fused add, 1 = pre-add with store
for fadd in "0" "1" ; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
example/ck_tile/03_gemm/CMakeLists.txt
View file @
f20e48f1
# NOTE(review): the previous revision hard-coded `set(CMAKE_BUILD_TYPE Debug)`
# here.  Build type must be chosen by the user/toolchain (e.g. on the cmake
# command line); forcing Debug in an example's CMakeLists silently disables
# optimization for everything configured after it, which invalidates any
# benchmark numbers.  The duplicated gemm_basic target from the diff is also
# collapsed to a single add_executable.
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp)
\ No newline at end of file
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_basic.hpp"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <cstring>
#include <cstring>
...
@@ -10,51 +9,48 @@
...
@@ -10,51 +9,48 @@
#include <string>
#include <string>
#include <tuple>
#include <tuple>
auto
create_args
(
int
argc
,
char
*
argv
[])
#include "ck_tile/ops/epilogue.hpp"
{
#include "ck_tile/ops/gemm.hpp"
ck_tile
::
ArgParser
arg_parser
;
#include "ck_tile/host.hpp"
arg_parser
.
insert
(
"b"
,
"1"
,
"batch size"
)
#include "gemm_basic.hpp"
.
insert
(
"m"
,
"1024"
,
"m dimension"
)
.
insert
(
"n"
,
"2048"
,
"n dimension"
)
.
insert
(
"k"
,
"64"
,
"k dimension"
)
.
insert
(
"stride_a"
,
"0"
,
"Tensor A stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
.
insert
(
"stride_c"
,
"0"
,
"Tensor C stride"
)
.
insert
(
"v"
,
"2"
,
"0. No validation, 1. Validation on CPU, 2. Validation on GPU"
)
.
insert
(
"e"
,
"1e-5"
,
"Absolute error tolerance"
)
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"10"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
template
<
typename
LayoutA
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
typename
LayoutB
,
typename
LayoutC
,
typename
PipelineProblem
,
typename
GemmPipeline
,
typename
GemmShape
>
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
gemm_basic_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadC
=
true
;
constexpr
bool
kTilePermute
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
// layout.
constexpr
bool
CShuffleEpilogue
=
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
LayoutC
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
CShuffleEpilogue
,
...
@@ -70,14 +66,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -70,14 +66,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
TilePartitioner
::
kN
>>
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadA
,
kPadB
>>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
<
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
Codegen
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_b
,
args
.
p_c
,
args
.
p_c
,
args
.
epsilon
,
args
.
M
,
args
.
M
,
args
.
N
,
args
.
N
,
args
.
K
,
args
.
K
,
...
@@ -88,299 +91,20 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -88,299 +91,20 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
float
ave_time
=
ck_tile
::
launch_kernel
(
if
(
s
.
log_level_
>
0
)
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
template
<
typename
DataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
,
typename
PipelineProblem
,
typename
GemmPipeline
,
typename
GemmShape
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_buf
,
ck_tile
::
DeviceMem
&
b_buf
,
ck_tile
::
DeviceMem
&
c_buf
,
const
ck_tile
::
ArgParser
&
arg_parser
)
{
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
if
(
data_type
!=
DataTypeTraits
<
DataType
>::
name
)
{
std
::
cerr
<<
"Data type mismatch: expected "
<<
DataTypeTraits
<
DataType
>::
name
<<
", got "
<<
data_type
<<
std
::
endl
;
return
-
1
;
// Or handle the error appropriately
}
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
ck_tile
::
index_t
batch_size
=
arg_parser
.
get_int
(
"b"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
stride_a
=
arg_parser
.
get_int
(
"stride_a"
);
ck_tile
::
index_t
stride_b
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_c
=
arg_parser
.
get_int
(
"stride_c"
);
gemm_basic_args
args
;
args
.
p_a
=
a_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_buf
.
GetDeviceBuffer
();
args
.
epsilon
=
epsilon
;
args
.
kbatch
=
batch_size
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
// Only set stride_M and stride_N if they are non-zero and not equal to K.
if
(
stride_a
!=
0
)
{
args
.
stride_A
=
stride_a
;
}
else
{
args
.
stride_A
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
LayoutA
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
M
;
}
else
{
return
K
;
}
}();
}
if
(
stride_b
!=
0
)
{
args
.
stride_B
=
stride_b
;
}
else
{
{
args
.
stride_B
=
[
&
]()
{
std
::
cout
<<
"Launching kernel with args:"
if
constexpr
(
std
::
is_same_v
<
LayoutB
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
{
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
return
N
;
<<
std
::
endl
;
}
else
{
return
K
;
}
}();
}
}
if
(
stride_c
!=
0
)
float
ave_time
=
ck_tile
::
launch_kernel
(
{
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
args
.
stride_C
=
stride_c
;
}
else
{
args
.
stride_C
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
LayoutC
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
M
;
}
else
{
return
N
;
}
}();
}
float
ave_time
=
gemm_calc
<
LayoutA
,
LayoutB
,
LayoutC
,
PipelineProblem
,
GemmPipeline
,
GemmShape
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
});
std
::
size_t
num_byte
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
N
*
K
+
sizeof
(
CDataType
)
*
M
*
N
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"The overall perfomance of the GEMM with "
<<
"["
<<
data_type
<<
"]"
<<
"batch size: "
<<
batch_size
<<
". m:"
<<
M
<<
", n:"
<<
N
<<
", k:"
<<
K
<<
" is:
\n
"
;
std
::
cout
<<
"Running time: "
<<
ave_time
<<
"ms, Throughput "
<<
gb_per_sec
<<
"GB/s
\n
"
<<
std
::
flush
;
return
ave_time
;
return
ave_time
;
}
}
int
main
(
int
argc
,
char
*
argv
[])
#include "run_gemm_example.inc"
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
K
=
arg_parser
.
get_int
(
"k"
);
// The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
using
matrix_a_layout
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
matrix_b_layout
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
matrix_c_layout
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
// host verify
std
::
vector
<
int
>
a_dimensions
=
(
std
::
is_same_v
<
matrix_a_layout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
?
std
::
vector
<
int
>
{
M
,
K
}
:
std
::
vector
<
int
>
{
K
,
M
};
std
::
vector
<
int
>
b_dimensions
=
(
std
::
is_same_v
<
matrix_b_layout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
std
::
vector
<
int
>
{
N
,
K
}
:
std
::
vector
<
int
>
{
K
,
N
};
std
::
vector
<
int
>
c_dimensions
=
(
std
::
is_same_v
<
matrix_c_layout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
?
std
::
vector
<
int
>
{
M
,
N
}
:
std
::
vector
<
int
>
{
N
,
M
};
ck_tile
::
HostTensor
<
ADataType
>
a_host
(
a_dimensions
);
ck_tile
::
HostTensor
<
BDataType
>
b_host
(
b_dimensions
);
ck_tile
::
HostTensor
<
CDataType
>
c_host_ref
(
c_dimensions
);
ck_tile
::
HostTensor
<
CDataType
>
c_host_dev
(
c_dimensions
);
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_host
);
ck_tile
::
DeviceMem
a_buf
(
a_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_buf
(
b_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_buf
(
c_host_dev
.
get_element_space_size_in_bytes
());
a_buf
.
ToDevice
(
a_host
.
data
());
b_buf
.
ToDevice
(
b_host
.
data
());
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadA
=
true
;
constexpr
bool
kPadB
=
true
;
constexpr
bool
kPadC
=
true
;
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
<
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
invoke_gemm
<
ck_tile
::
half_t
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
,
CodegenPipelineProblem
,
CodegenGemmPipeline
,
CodegenGemmShape
>
(
a_buf
,
b_buf
,
c_buf
,
arg_parser
);
c_buf
.
FromDevice
(
c_host_dev
.
data
());
bool
pass_cpu
=
true
;
if
(
arg_parser
.
get_int
(
"v"
)
==
1
)
{
// ToDo: Will Add the Element Op (bias) verification in the future.
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
(
a_host
,
b_host
,
c_host_ref
);
pass_cpu
=
ck_tile
::
check_err
(
c_host_dev
,
c_host_ref
);
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass_cpu
?
"correct"
:
"fail"
)
<<
std
::
flush
;
}
bool
pass_gpu
=
true
;
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
ck_tile
::
index_t
stride_a
=
arg_parser
.
get_int
(
"stride_a"
);
ck_tile
::
index_t
stride_b
=
arg_parser
.
get_int
(
"stride_b"
);
ck_tile
::
index_t
stride_c
=
arg_parser
.
get_int
(
"stride_c"
);
if
(
stride_a
==
0
)
{
if
constexpr
(
std
::
is_same_v
<
matrix_a_layout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
stride_a
=
M
;
}
else
{
stride_a
=
K
;
}
}
if
(
stride_b
==
0
)
{
if
constexpr
(
std
::
is_same_v
<
matrix_b_layout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
stride_b
=
N
;
}
else
{
stride_b
=
K
;
}
}
if
(
stride_c
==
0
)
{
if
constexpr
(
std
::
is_same_v
<
matrix_c_layout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
stride_c
=
M
;
}
else
{
stride_c
=
N
;
}
}
ck_tile
::
HostTensor
<
CDataType
>
c_host_gpu_ref
(
c_dimensions
);
ck_tile
::
DeviceMem
c_gpu_buf
(
c_host_gpu_ref
.
get_element_space_size_in_bytes
());
ck_tile
::
reference_gemm_gpu
<
ADataType
,
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
BDataType
,
AccDataType
,
CDataType
,
matrix_a_layout
,
matrix_b_layout
,
matrix_c_layout
>
(
a_buf
,
b_buf
,
c_gpu_buf
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_buf
.
FromDevice
(
c_host_gpu_ref
.
data
());
pass_gpu
=
ck_tile
::
check_err
(
c_host_dev
,
c_host_gpu_ref
);
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass_gpu
?
"correct"
:
"fail"
)
<<
std
::
flush
;
}
std
::
cout
<<
std
::
endl
<<
std
::
flush
;
return
!
pass_gpu
;
}
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
f20e48f1
...
@@ -4,12 +4,10 @@
...
@@ -4,12 +4,10 @@
#pragma once
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include <string>
template
<
typename
DataType
>
template
<
typename
DataType
>
struct
GemmBasicTypeConfig
;
struct
GemmBasicTypeConfig
;
...
@@ -20,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
...
@@ -20,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
using
ADataType
=
ck_tile
::
half_t
;
using
ADataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
// type convert
using
CDataType
=
ck_tile
::
half_t
;
// ToDo: Add more bias config to support different categories of GEMM.
// ToDo: Add more bias config to support different categories of GEMM.
};
};
...
@@ -58,7 +56,6 @@ struct gemm_basic_args
...
@@ -58,7 +56,6 @@ struct gemm_basic_args
const
void
*
p_a
;
const
void
*
p_a
;
const
void
*
p_b
;
const
void
*
p_b
;
void
*
p_c
;
void
*
p_c
;
float
epsilon
;
ck_tile
::
index_t
kbatch
;
ck_tile
::
index_t
kbatch
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
N
;
...
@@ -68,5 +65,28 @@ struct gemm_basic_args
...
@@ -68,5 +65,28 @@ struct gemm_basic_args
ck_tile
::
index_t
stride_C
;
ck_tile
::
index_t
stride_C
;
};
};
// Builds the command-line parser for the GEMM examples and parses argv.
// Returns (parse_ok, parser); callers read individual options from the
// parser via get_int/get_str/get_float.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    // Each insert is (flag, default value, help text).
    arg_parser.insert("b", "1", "batch size")
        .insert("m", "3840", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("k", "2048", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "R", "B tensor data layout - Row by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        // stride 0 means "derive the natural stride from the layout"
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// host API
// host API
float
gemm_calc
(
gemm_basic_args
args
,
const
ck_tile
::
stream_config
&
s
);
float
gemm_calc
(
gemm_basic_args
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/03_gemm/gemm_mem_pipeline.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
// Memory-bound-pipeline GEMM launcher.  Fixes the tile configuration at
// compile time, computes the runtime (has_hot_loop, tail_num) pair from K,
// and dispatches to a dedicated kernel instantiation for each combination
// (the pipeline needs both as compile-time constants).  Returns the value
// reported by ck_tile::launch_kernel.
template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{
    // ToDo: This will be modified by the codegen code later.
    // Block tile (per work-group) extents.
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 128;
    constexpr ck_tile::index_t K_Tile = 32;

    // Warp grid inside a block.
    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    // Per-warp tile extents.
    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;

    // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
    constexpr bool kPadA = true;
    constexpr bool kPadB = true;
    constexpr bool kPadC = true;

    constexpr int kBlockPerCu = 1;

    // ===============================================
    using GemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
    using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;

    // Only C is padded by the epilogue here (first flag is the "save mean"-style
    // knob of Default2DEpilogueProblem -- kept false).
    using GemmEpilogue = ck_tile::Default2DEpilogue<
        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;

    using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;

    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;

    // Runtime dispatch keys: how many K-loop iterations, whether the pipeline
    // has a "hot" steady-state loop, and which tail variant closes it.
    const ck_tile::index_t num_loop      = TilePartitioner::GetLoopNum(args.K);
    const bool has_hot_loop              = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num   = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};

    // Instantiates and launches the kernel for one compile-time
    // (has_hot_loop, tail_number) combination.
    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v  = tail_number_.value;

        using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
            ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                  BDataType,
                                                  AccDataType,
                                                  GemmShape,
                                                  Traits,
                                                  ck_tile::GemmPipelineScheduler::Intrawave,
                                                  has_hot_loop_v,
                                                  tail_number_v>>;
        using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
        auto kargs   = Kernel::MakeKargs(args.p_a,
                                         args.p_b,
                                         args.p_c,
                                         args.M,
                                         args.N,
                                         args.K,
                                         args.stride_A,
                                         args.stride_B,
                                         args.stride_C);

        const dim3 grids      = Kernel::GridSize(args.M, args.N, args.kbatch);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:"
                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        ave_time = ck_tile::launch_kernel(
            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        return ave_time;
    };

    if(has_hot_loop)
    {
        // Tail pipeline One to Seven
        if(tail_num == ck_tile::TailNumber::One)
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
        }
        else if(tail_num == ck_tile::TailNumber::Full)
        {
            Run(ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }

        // Tail numbers Two..Seven only exist when the pipeline prefetches at
        // least that many stages, hence the `if constexpr` guards: they keep
        // invalid instantiations from being compiled at all.
        if constexpr(BaseGemmPipeline::PrefetchStages > 2)
        {
            if(tail_num == ck_tile::TailNumber::Two)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
            }
        }

        if constexpr(BaseGemmPipeline::PrefetchStages > 3)
        {
            if(tail_num == ck_tile::TailNumber::Three)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
            }
        }

        if constexpr(BaseGemmPipeline::PrefetchStages > 4)
        {
            if(tail_num == ck_tile::TailNumber::Four)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
            }
        }

        if constexpr(BaseGemmPipeline::PrefetchStages > 5)
        {
            if(tail_num == ck_tile::TailNumber::Five)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
            }
        }

        if constexpr(BaseGemmPipeline::PrefetchStages > 6)
        {
            if(tail_num == ck_tile::TailNumber::Six)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
            }
        }

        if constexpr(BaseGemmPipeline::PrefetchStages > 7)
        {
            if(tail_num == ck_tile::TailNumber::Seven)
            {
                Run(ck_tile::bool_constant<true>{},
                    ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
            }
        }
    }
    else
    {
        // Tail number always Full - #PrefetchStages
        if(tail_num == ck_tile::TailNumber::Full)
        {
            Run(ck_tile::bool_constant<false>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }
        else
        {
            // Any other tail number without a hot loop is an unsupported
            // configuration -- fail loudly with the location for debugging.
            std::ostringstream err;
            err << "When there's no hot loop, this tail number \"" << tail_num
                << "\" is not supported! " << __FILE__ << ":" << __LINE__
                << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
    }

    return ave_time;
}
#include "run_gemm_example.inc"
int main(int argc, char* argv[])
{
    // run_gemm_example returns 1 on pass, 0 on verification failure, and -1
    // when argument parsing fails. The previous `!run_gemm_example(...)`
    // mapped the parse-error path (-1, non-zero) to exit code 0 (success).
    // Treat anything other than a pass as a failing exit code instead.
    return run_gemm_example(argc, argv) == 1 ? 0 : 1;
}
example/ck_tile/03_gemm/run_gemm_example.inc
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
/// @brief Fill the kernel argument struct, launch the tiled GEMM and report
///        timing / throughput figures.
///
/// @param a_m_k_dev_buf device buffer holding A (M x K)
/// @param b_k_n_dev_buf device buffer holding B (K x N)
/// @param c_m_n_dev_buf device buffer receiving C (M x N)
/// @param M,N,K         GEMM problem sizes
/// @param stride_A/B/C  leading-dimension strides of the three tensors
/// @param kbatch        split-K batch count forwarded to the kernel
/// @param n_warmup      warm-up launches before timing
/// @param n_repeat      timed launches
/// @return average kernel time in milliseconds as reported by gemm_calc
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                  ck_tile::DeviceMem& b_k_n_dev_buf,
                  ck_tile::DeviceMem& c_m_n_dev_buf,
                  ck_tile::index_t M,
                  ck_tile::index_t N,
                  ck_tile::index_t K,
                  ck_tile::index_t stride_A,
                  ck_tile::index_t stride_B,
                  ck_tile::index_t stride_C,
                  ck_tile::index_t kbatch,
                  int n_warmup,
                  int n_repeat)
{
    // Pack every launch parameter into the argument struct consumed by
    // gemm_calc.
    gemm_basic_args args;
    args.p_a      = a_m_k_dev_buf.GetDeviceBuffer();
    args.p_b      = b_k_n_dev_buf.GetDeviceBuffer();
    args.p_c      = c_m_n_dev_buf.GetDeviceBuffer();
    args.kbatch   = kbatch;
    args.M        = M;
    args.N        = N;
    args.K        = K;
    args.stride_A = stride_A;
    args.stride_B = stride_B;
    args.stride_C = stride_C;

    // Time the kernel: stream_config{stream, time_kernel, log_level, warmup,
    // repeat}.
    const float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    const std::string op_name{"Gemm{MemBoundPipeline}"};

    // 2*M*N*K multiply-adds; bytes moved = A + B read, C written once.
    const std::size_t flop     = std::size_t(2) * M * N * K;
    const std::size_t num_byte =
        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;

    // ave_time is in ms, so flop/1e9/ms == TFlop/s and byte/1e6/ms == GB/s.
    const float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    const float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Run " << op_name << " kernel with M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;

    return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
ck_tile::index_t M = arg_parser.get_int("
m
");
ck_tile::index_t N = arg_parser.get_int("
n
");
ck_tile::index_t K = arg_parser.get_int("
k
");
ck_tile::index_t stride_A = arg_parser.get_int("
stride_a
");
ck_tile::index_t stride_B = arg_parser.get_int("
stride_b
");
ck_tile::index_t stride_C = arg_parser.get_int("
stride_c
");
ck_tile::index_t batch_size = arg_parser.get_int("
b
");
int n_warmup = arg_parser.get_int("
warmup
");
int n_repeat = arg_parser.get_int("
repeat
");
using namespace ck_tile::literals;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
stride_A = f_get_default_stride(M, K, stride_A, a_layout);
stride_B = f_get_default_stride(K, N, stride_B, b_layout);
stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, a_layout));
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, b_layout));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
// TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_size,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("
v
") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
std::cout << "
The
CPU
veification
result
is
:
" << (pass ? "
correct
" : "
fail
") << std::endl;
}
else if(arg_parser.get_int("
v
") == 2)
{
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
std::cout << "
The
GPU
veification
result
is
:
" << (pass ? "
correct
" : "
fail
") << std::endl;
}
return pass;
}
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("
a_layout
");
std::string b_layout = arg_parser.get_str("
b_layout
");
if(a_layout == "
R
" && b_layout == "
R
")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "
R
" && b_layout == "
C
")
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "
C
" && b_layout == "
C
")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
}
else if(a_layout == "
C
" && b_layout == "
R
")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
}
else
{
throw std::runtime_error("
Unsupported
data
layout
configuration
for
A
,
B
and
C
tensors
!
");
}
}
example/ck_tile/05_reduce/reduce.cpp
View file @
f20e48f1
...
@@ -19,9 +19,9 @@ auto create_args(int argc, char* argv[])
...
@@ -19,9 +19,9 @@ auto create_args(int argc, char* argv[])
template
<
typename
DataType
>
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
using
A
DataType
=
DataType
;
using
X
DataType
=
DataType
;
using
Acc
DataType
=
float
;
using
Compute
DataType
=
float
;
using
B
DataType
=
DataType
;
using
Y
DataType
=
DataType
;
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
...
@@ -29,35 +29,39 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -29,35 +29,39 @@ bool run(const ck_tile::ArgParser& arg_parser)
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
ck_tile
::
HostTensor
<
A
DataType
>
a
_host
({
m
,
n
});
ck_tile
::
HostTensor
<
X
DataType
>
x
_host
({
m
,
n
});
ck_tile
::
HostTensor
<
B
DataType
>
b
_host_ref
({
m
});
ck_tile
::
HostTensor
<
Y
DataType
>
y
_host_ref
({
m
});
ck_tile
::
HostTensor
<
B
DataType
>
b
_host_dev
({
m
});
ck_tile
::
HostTensor
<
Y
DataType
>
y
_host_dev
({
m
});
ck_tile
::
FillUniformDistribution
<
A
DataType
>
{
-
5.
f
,
5.
f
}(
a
_host
);
ck_tile
::
FillUniformDistribution
<
X
DataType
>
{
-
5.
f
,
5.
f
}(
x
_host
);
ck_tile
::
DeviceMem
a
_buf
(
a
_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
x
_buf
(
x
_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b
_buf
(
b
_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y
_buf
(
y
_host_dev
.
get_element_space_size_in_bytes
());
a
_buf
.
ToDevice
(
a
_host
.
data
());
x
_buf
.
ToDevice
(
x
_host
.
data
());
using
ReduceOp
=
ck_tile
::
ReduceOp
::
Add
;
using
BlockWarps
=
ck_tile
::
sequence
<
4
,
1
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
4
,
1
>
;
using
BlockTile
=
ck_tile
::
sequence
<
128
,
128
>
;
using
BlockTile
=
ck_tile
::
sequence
<
128
,
128
>
;
using
WarpTile
=
ck_tile
::
sequence
<
32
,
128
>
;
using
WarpTile
=
ck_tile
::
sequence
<
32
,
128
>
;
using
ThreadTile
=
ck_tile
::
sequence
<
8
,
8
>
;
using
Vector
=
ck_tile
::
sequence
<
8
,
8
>
;
constexpr
ck_tile
::
index_t
kBlockSize
=
256
;
// cross warp-reduce
// using BlockWarps = ck_tile::sequence<2, 2>;
// using BlockTile = ck_tile::sequence<2, 1024>;
// using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>;
constexpr
ck_tile
::
index_t
kBlockSize
=
512
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
ck_tile
::
index_t
kGridSize
=
(
m
/
BlockTile
::
at
(
ck_tile
::
number
<
0
>
{}));
ck_tile
::
index_t
kGridSize
=
(
m
/
BlockTile
::
at
(
ck_tile
::
number
<
0
>
{}));
std
::
cout
<<
"grid size "
<<
kGridSize
<<
std
::
endl
;
std
::
cout
<<
"grid size "
<<
kGridSize
<<
std
::
endl
;
using
Kernel
=
ck_tile
::
Reduce
<
ADataType
,
using
Shape
=
ck_tile
::
Reduce2dShape
<
BlockWarps
,
BlockTile
,
WarpTile
,
Vector
>
;
AccDataType
,
using
Porblem
=
BDataType
,
ck_tile
::
Reduce2dProblem
<
XDataType
,
ComputeDataType
,
YDataType
,
Shape
,
ReduceOp
>
;
kBlockSize
,
BlockWarps
,
using
Kernel
=
ck_tile
::
Reduce
<
Porblem
>
;
BlockTile
,
WarpTile
,
ThreadTile
>
;
float
ave_time
=
launch_kernel
(
ck_tile
::
stream_config
{
nullptr
,
true
,
0
,
warmup
,
repeat
},
float
ave_time
=
launch_kernel
(
ck_tile
::
stream_config
{
nullptr
,
true
,
0
,
warmup
,
repeat
},
ck_tile
::
make_kernel
<
kBlockSize
,
kBlockPerCu
>
(
ck_tile
::
make_kernel
<
kBlockSize
,
kBlockPerCu
>
(
...
@@ -65,12 +69,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -65,12 +69,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
kGridSize
,
kGridSize
,
kBlockSize
,
kBlockSize
,
0
,
0
,
static_cast
<
A
DataType
*>
(
a
_buf
.
GetDeviceBuffer
()),
static_cast
<
X
DataType
*>
(
x
_buf
.
GetDeviceBuffer
()),
static_cast
<
B
DataType
*>
(
b
_buf
.
GetDeviceBuffer
()),
static_cast
<
Y
DataType
*>
(
y
_buf
.
GetDeviceBuffer
()),
m
,
m
,
n
));
n
));
std
::
size_t
num_btype
=
sizeof
(
A
DataType
)
*
m
*
n
+
sizeof
(
B
DataType
)
*
m
;
std
::
size_t
num_btype
=
sizeof
(
X
DataType
)
*
m
*
n
+
sizeof
(
Y
DataType
)
*
m
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
...
@@ -81,9 +85,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -81,9 +85,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
if
(
do_validation
)
{
{
// reference
// reference
ck_tile
::
reference_reduce
<
ADataType
,
AccDataType
,
BDataType
>
(
a_host
,
b_host_ref
);
ck_tile
::
reference_reduce
<
XDataType
,
ComputeDataType
,
YDataType
>
(
b_buf
.
FromDevice
(
b_host_dev
.
mData
.
data
());
x_host
,
y_host_ref
,
ReduceOp
{});
pass
=
ck_tile
::
check_err
(
b_host_dev
,
b_host_ref
);
y_buf
.
FromDevice
(
y_host_dev
.
mData
.
data
());
pass
=
ck_tile
::
check_err
(
y_host_dev
,
y_host_ref
);
std
::
cout
<<
"valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
"valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
}
...
@@ -103,8 +108,8 @@ int main(int argc, char* argv[])
...
@@ -103,8 +108,8 @@ int main(int argc, char* argv[])
{
{
return
run
<
ck_tile
::
half_t
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
half_t
>
(
arg_parser
)
?
0
:
-
2
;
}
}
if
(
data_type
==
"bf16"
)
// else
if(data_type == "bf16")
{
//
{
return
run
<
ck_tile
::
bf16_t
>
(
arg_parser
)
?
0
:
-
2
;
//
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}
//
}
}
}
example/ck_tile/05_reduce/reduce.hpp
View file @
f20e48f1
...
@@ -5,20 +5,16 @@
...
@@ -5,20 +5,16 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
ADataType
,
template
<
typename
BlockWarps
,
// num warps along seq<M, N>
typename
AccDataType
,
typename
BDataType
,
index_t
kBlockSize
,
typename
BlockWarps
,
// num warps along seq<M, N>
typename
BlockTile
,
// block size, seq<M, N>
typename
BlockTile
,
// block size, seq<M, N>
typename
WarpTile
,
// warp size, seq<M, N>
typename
WarpTile
,
// warp size, seq<M, N>
typename
ThreadTile
>
// contiguous pixels(vector size) along seq<M, N>
typename
Vector
>
// contiguous pixels(vector size) along seq<M, N>
struct
Reduce
struct
Reduce
2dShape
{
{
static
constexpr
index_t
Block_M
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_M
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_N
=
BlockTile
::
at
(
number
<
1
>
{});
...
@@ -26,93 +22,143 @@ struct Reduce
...
@@ -26,93 +22,143 @@ struct Reduce
static
constexpr
index_t
Warp_M
=
WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_M
=
WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N
=
WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_N
=
WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Thread_M
=
ThreadTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Vector_M
=
Vector
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Thread_N
=
ThreadTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Vector_N
=
Vector
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_M
=
BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_M
=
BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_N
=
BlockWarps
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_N
=
BlockWarps
::
at
(
number
<
1
>
{});
static
constexpr
index_t
ThreadPerWarp_M
=
Warp_M
/
Thread
_M
;
static
constexpr
index_t
ThreadPerWarp_M
=
Warp_M
/
Vector
_M
;
static
constexpr
index_t
ThreadPerWarp_N
=
Warp_N
/
Thread
_N
;
static
constexpr
index_t
ThreadPerWarp_N
=
Warp_N
/
Vector
_N
;
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
WarpPerBlock_M
*
Warp_M
);
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
WarpPerBlock_M
*
Warp_M
);
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
WarpPerBlock_N
*
Warp_N
);
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
WarpPerBlock_N
*
Warp_N
);
__device__
static
constexpr
auto
MakeABlockTileDistribution
()
static
constexpr
index_t
BlockSize
=
{
warpSize
*
reduce_on_sequence
(
BlockWarps
{},
multiplies
{},
number
<
1
>
{});
return
make_static_tile_distribution
(
};
tile_distribution_encoding
<
sequence
<>
,
template
<
typename
XDataType_
,
tuple
<
sequence
<
Repeat_M
,
WarpPerBlock_M
,
ThreadPerWarp_M
,
Thread_M
>
,
typename
ComputeDataType_
,
sequence
<
Repeat_N
,
WarpPerBlock_N
,
ThreadPerWarp_N
,
Thread_N
>>
,
typename
YDataType_
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
typename
BlockShape_
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
typename
ReduceOp_
>
sequence
<
1
,
1
,
2
,
2
>
,
struct
Reduce2dProblem
sequence
<
0
,
3
,
0
,
3
>>
{});
{
}
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
ReduceOp
=
ReduceOp_
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
};
template
<
typename
Problem_
,
typename
Policy_
=
BlockReduce2dDefaultPolicy
>
struct
Reduce
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
__device__
void
operator
()(
const
ADataType
*
p_a
,
BDataType
*
p_b
,
index_t
M
,
index_t
N
)
const
#if 0
CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N)
const
{
{
const
auto
a_m_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
using S = typename Problem::BlockShape;
p_a
,
make_tuple
(
M
,
N
),
make_tuple
(
N
,
1
),
number
<
Thread_N
>
{},
number
<
1
>
{});
const
auto
iM
=
get_block_id
()
*
Block_M
;
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
// A window
const auto y_m = make_naive_tensor_view_packed<address_space_enum::global>(
auto
a_block_window
=
make_tile_window
(
a_m_n
,
p_y, make_tuple(M), number<1>{});
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
},
const auto iM = get_block_id() * S::Block_M;
MakeABlockTileDistribution
());
auto x_window = make_tile_window(x_m_n,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{iM, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
auto y_window = make_tile_window(y_m, make_tuple(number<S::Block_M>{}), {iM});
const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };
const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };
const
A
DataType
reduce_init_value
=
0
;
const
X
DataType reduce_init_value = 0;
constexpr auto reduce_dims = sequence<1>{};
constexpr auto reduce_dims = sequence<1>{};
// Acc tile
auto y_compute = decltype(block_tile_reduce<ComputeDataType>(
// TODO: support cross warp reduction
load_tile(x_window), reduce_dims, f_reduce, reduce_init_value)){};
auto
acc_block_tensor
=
decltype
(
block_tile_reduce
<
AccDataType
>
(
load_tile
(
a_block_window
),
reduce_dims
,
f_reduce
,
reduce_init_value
)){};
// init Acc tile
set_tile(y_compute, reduce_init_value);
tile_elementwise_inout
(
[
&
](
auto
&
acc
)
{
acc
=
type_convert
<
AccDataType
>
(
reduce_init_value
);
},
acc_block_tensor
);
// loop
index_t num_n_tile_iteration =
index_t
iN
=
0
;
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N))
;
do
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
{
const
auto
a_block_tensor
=
load_tile
(
a_block_window
);
const auto x = load_tile(x_window);
block_tile_reduce(y_compute, x, reduce_dims, f_reduce);
move_tile_window(x_window, {0, S::Block_N});
}
// FIXME: support cross warp reduction
block_tile_reduce_sync(y_compute, f_reduce);
block_tile_reduce
(
acc_block_tensor
,
a_block_tensor
,
reduce_dims
,
f_reduce
);
store_tile(y_window, cast_tile<YDataType>(y_compute));
}
#else
CK_TILE_DEVICE
void
operator
()(
const
XDataType
*
p_x
,
YDataType
*
p_y
,
index_t
M
,
index_t
N
)
const
{
using
S
=
typename
Problem
::
BlockShape
;
move_tile_window
(
a_block_window
,
{
0
,
Block_N
});
const
auto
x_m_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_x
,
make_tuple
(
M
,
N
),
make_tuple
(
N
,
1
),
number
<
S
::
Vector_N
>
{},
number
<
1
>
{});
iN
+=
Block_N
;
const
auto
y_m
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_y
,
make_tuple
(
M
),
number
<
1
>
{});
}
while
(
iN
<
N
)
;
const
auto
iM
=
get_block_id
()
*
S
::
Block_M
;
// FIXME: support cross warp reduction
auto
x_window
=
make_tile_window
(
x_m_n
,
block_tile_reduce_sync
(
acc_block_tensor
,
f_reduce
);
make_tuple
(
number
<
S
::
Block_M
>
{},
number
<
S
::
Block_N
>
{}),
{
iM
,
0
},
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
// convert acc_block_tensor to b_block_tensor
auto
y_window
=
make_tile_window
(
y_m
,
make_tuple
(
number
<
S
::
Block_M
>
{}),
{
iM
});
const
auto
b_block_tensor
=
tile_elementwise_in
(
[](
const
auto
&
acc
)
{
return
type_convert
<
BDataType
>
(
acc
);
},
acc_block_tensor
);
// B
__shared__
char
smem
[
Policy
::
template
GetSmemSize
<
Problem
>()];
const
auto
b_m
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_b
,
make_tuple
(
M
),
number
<
32
>
{});
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
N
,
S
::
Block_N
));
auto
reduce_func
=
typename
Problem
::
ReduceOp
{};
auto
block_reduce2d
=
Policy
::
template
GetBlockReduce2d
<
Problem
>();
auto
block_reduce2d_sync
=
Policy
::
template
GetBlockReduce2dSync
<
Problem
>();
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
load_tile
(
x_window
));
auto
y_compute
=
block_reduce2d
.
template
MakeYBlockTile
<
XTensorType
>();
set_tile
(
y_compute
,
reduce_func
.
template
GetIdentityValue
<
ComputeDataType
>());
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
block_reduce2d
(
x
,
y_compute
,
reduce_func
);
move_tile_window
(
x_window
,
{
0
,
S
::
Block_N
});
}
// B window
block_reduce2d_sync
(
y_compute
,
reduce_func
);
auto
b_block_window
=
make_tile_window
(
b_m
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
}
);
block_reduce2d_cross_warp_sync
(
y_compute
,
smem
,
reduce_func
);
// store B tile
store_tile
(
y_window
,
cast_tile
<
YDataType
>
(
y_compute
));
store_tile
(
b_block_window
,
b_block_tensor
);
}
}
#endif
};
};
}
// namespace ck_tile
}
// namespace ck_tile
example/ck_tile/06_permute/CMakeLists.txt
0 → 100644
View file @
f20e48f1
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp)

# Default to the alternative (matrix-core swizzle) implementation unless the
# caller explicitly sets PERMUTE_USE_ALTERNATIVE_IMPL on the command line.
if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
    # set(PERMUTE_USE_ALTERNATIVE_IMPL false)
    set(PERMUTE_USE_ALTERNATIVE_IMPL true)
endif()

if(PERMUTE_USE_ALTERNATIVE_IMPL)
    target_compile_options(tile_example_permute PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL)
    target_sources(tile_example_permute PRIVATE alternative_impl/matrix_core_swizzle.cpp)
endif()

# target_compile_options(tile_example_permute PRIVATE -v --save-temps -Wno-gnu-line-marker)
Prev
1
2
3
4
5
6
7
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment