gaoqiong / composable_kernel_ROCM / Commits / da7ef7e8

Commit da7ef7e8 (unverified), authored Jul 08, 2024 by Jun Liu, committed by GitHub on Jul 08, 2024

Merge pull request #88 from ROCm/merge_from_public

Merge from public

Parents: 9f2a6d43, dcd3d21a
Changes: 379

Showing 20 changed files with 724 additions and 140 deletions (+724 -140)
Changed files:

  example/ck_tile/01_fmha/script/smoke_test_fwd.sh                                    +13  -10
  example/ck_tile/01_fmha/utils.hpp                                                    +96  -6
  example/ck_tile/02_layernorm2d/CMakeLists.txt                                        +4   -0
  example/ck_tile/02_layernorm2d/README.md                                             +22  -0
  example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp                                   +191 -0
  example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp                                   +30  -0
  example/ck_tile/CMakeLists.txt                                                       +1   -0
  include/ck/ck.hpp                                                                    +1   -1
  include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp                         +4   -4
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp             +18  -23
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp          +11  -12
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp          +17  -19
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp          +7   -8
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp          +13  -16
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp          +11  -14
  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                        +7   -9
  include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                      +10  -13
  include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp           +4   -5
  include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp    +220 -0
  include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp                    +44  -0
example/ck_tile/01_fmha/script/smoke_test.sh → example/ck_tile/01_fmha/script/smoke_test_fwd.sh

@@ -17,17 +17,19 @@ for perm in 0 1 ; do
 for vlayout in "r" "c" ; do
 for hdim in 32 64 128 256 ; do
 for lse in 0 1 ; do
 for bias in "n" "e" "a" ; do
+for p_drop in 0.0 0.2 ; do
-# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
 done
 done

@@ -36,6 +38,7 @@ done
 done
 done
 done
+done
 for perm in 0 1 ; do
 for bias in "n" "e" "a" ; do
example/ck_tile/01_fmha/utils.hpp

@@ -4,12 +4,14 @@
 #pragma once

 #include <cstdint>
 #include <cstdlib>
 #include <optional>
 #include <ostream>
 #include <tuple>
 #include <utility>
 #include <vector>
 #include <functional>
 #include <string>

 #include "ck_tile/core/container/span.hpp"

@@ -37,12 +39,14 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
 std::vector<int32_t> generate_seqlens(mode_enum mode,
                                       unsigned count,
-                                      int32_t seqlens_sum,
+                                      int32_t seqlen_avg,
+                                      int32_t seqlen_max = -1, // if not negative, clamp max
                                       std::optional<unsigned> seed = std::nullopt)
 {
     assert(0 < count);

-    std::vector<int32_t> seqlens(count, seqlens_sum);
+    std::vector<int32_t> seqlens(
+        count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);

     if(mode == mode_enum::group && 1 < count)
     {

@@ -55,7 +59,7 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
         std::uniform_int_distribution<size_type> step_dist(1, count - 1);
         auto next_step = std::bind(step_dist, std::ref(random_engine));

-        for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat)
+        for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
         {
             const size_type to_decrease = next_idx();
             // make sure each elements of seqlens is always greater than 0

@@ -66,6 +70,11 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
             const size_type to_increase = (to_decrease + next_step()) % count;

+            if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
+            {
+                continue;
+            }
+
             --seqlens[to_decrease];
             ++seqlens[to_increase];
         }

@@ -76,10 +85,91 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
 std::vector<int32_t> generate_seqstarts(mode_enum mode,
                                         unsigned count,
-                                        int32_t seqlens_sum,
+                                        int32_t seqlen_avg,
+                                        int32_t seqlen_max = -1,
                                         std::optional<unsigned> seed = std::nullopt)
 {
-    return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed));
+    return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
 }

+/*
+ * decode the seqlen string from cmdline
+ * example (assume batch=3)
+ *  q_val=1,2,3 k_val=4,5,6 -> OK
+ *  q_val=1,2,3 -> OK, k same as q
+ *  q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
+ *  q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
+ *  q_val=1,2,3,4 -> OK, but ignore exceed one
+ *
+ *  q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
+ *  q_val=1,2 k_val=4 -> not OK, k must have same splits with q
+ */
+std::tuple<std::vector<ck_tile::index_t>,
+           std::vector<ck_tile::index_t>,
+           std::vector<ck_tile::index_t>>
+decode_seqlen(mode_enum mode,
+              ck_tile::index_t batch,
+              std::string q_val,
+              std::string k_val,
+              std::string k_pad_val,
+              std::optional<unsigned> seed = std::nullopt)
+{
+#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
+    if(mode == mode_enum::batch)
+    {
+        ck_tile::index_t q = _S2I_(q_val);
+        ck_tile::index_t k = _S2I_(k_val);
+        auto s_q    = std::vector<ck_tile::index_t>(batch, q);
+        auto s_k    = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
+        auto s_kpad = std::vector<ck_tile::index_t>(batch, -1);
+        // TODO: batch not support k_padding
+        return std::make_tuple(s_q, s_k, s_kpad);
+    }
+    else
+    {
+        ck_tile::index_t idx          = 0;
+        std::string::size_type pos_q  = 0;
+        std::string::size_type pos_k  = 0;
+        std::string::size_type pos_kp = 0;
+        std::vector<ck_tile::index_t> s_q;
+        std::vector<ck_tile::index_t> s_k;
+        std::vector<ck_tile::index_t> s_kpad;
+        while(true)
+        {
+            auto found_q  = q_val.find(',', pos_q);
+            auto found_k  = k_val.find(',', pos_k);
+            auto found_kp = k_pad_val.find(',', pos_kp);
+
+            ck_tile::index_t q = _S2I_(
+                q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
+            ck_tile::index_t k = _S2I_(
+                k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
+            ck_tile::index_t kp = _S2I_(k_pad_val.substr(
+                pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
+
+            s_q.push_back(q);
+            s_k.push_back(k < 0 ? q : k);
+            s_kpad.push_back(kp);
+            idx++;
+
+            if(found_q == std::string::npos || idx >= batch)
+            {
+                break;
+            }
+            pos_q  = found_q + 1;
+            pos_k  = found_k == std::string::npos ? pos_k : found_k + 1;
+            pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
+        }
+        if(idx < batch)
+        {
+            auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
+            auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
+            s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
+            s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
+            s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
+        }
+        return std::make_tuple(s_q, s_k, s_kpad);
+    }
+#undef _S2I_
+}

 int env_get_int(const char* var_name, int default_int)

@@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int)
     char* v = getenv(var_name);
     int r   = default_int;
     if(v)
-        r = atoi(v);
+        r = std::atoi(v);
     return r;
 }
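For reference, a minimal sketch of how the new decode_seqlen helper would be called in group mode, mirroring the cmdline examples in the comment above; the concrete values are illustrative and not taken from the commit:

    // assume batch = 3, q lengths "1,2,3", k lengths "4,5,6", no explicit k padding ("-1")
    auto [s_q, s_k, s_kpad] = decode_seqlen(mode_enum::group, 3, "1,2,3", "4,5,6", "-1");
    // yields s_q = {1, 2, 3}, s_k = {4, 5, 6}, s_kpad = {-1, -1, -1}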
example/ck_tile/02_layernorm2d/CMakeLists.txt  (new file mode 100644)

# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD)
(no newline at end of file)
example/ck_tile/02_layernorm2d/README.md  (new file mode 100644)

# Layernorm2D forward

This folder contains an example of Layernorm2D forward using the ck_tile tile-programming implementation.

## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace <arch> with gfx90a, gfx942...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`

## example
```
args:
  -m       m dimension (default:3328)
  -n       n dimension (default:4096)
  -e       epsilon (default:1e-5)
  -v       cpu validation or not (default:1)
  -prec    precision (default:fp16)
```
(no newline at end of file)
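As a concrete illustration (an editorial example, not part of the diff), a run with the default sizes would look roughly like the following; the flags use the same "-name=value" style as the ck_tile ArgParser seen elsewhere in this commit:

```
# hypothetical invocation from the build directory created above
./bin/tile_example_layernorm2d_fwd -m=3328 -n=4096 -e=1e-5 -v=1 -prec=fp16
```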
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp  (new file mode 100644)

#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include <cstring>

// Host API implementation
float layernorm2d_fwd(layernorm2d_fwd_traits t,
                      layernorm2d_fwd_args a,
                      const ck_tile::stream_config& s)
{
    if(t.data_type.compare("fp16") == 0)
    {
        using XDataType     = ck_tile::half_t;
        using YDataType     = ck_tile::half_t;
        using GammaDataType = ck_tile::half_t;
        using BetaDataType  = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
        using MeanDataType   = ck_tile::half_t;
        using InvStdDataType = ck_tile::half_t;
#else
        using MeanDataType   = ck_tile::null_type;
        using InvStdDataType = ck_tile::null_type;
#endif
        using ComputeDataType = float;

        using thread_tile = ck_tile::sequence<4, 4>;
        using warp_tile   = ck_tile::sequence<8, 128>;
        using block_tile  = ck_tile::sequence<32, 128>;

        using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;

        using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
                                                                    GammaDataType,
                                                                    BetaDataType,
                                                                    ComputeDataType,
                                                                    YDataType,
                                                                    MeanDataType,
                                                                    InvStdDataType,
                                                                    Shape>;

        using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

        auto kargs = Kernel::MakeKargs(
            a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);

        const dim3 grids                       = Kernel::GridSize(a.M);
        constexpr dim3 blocks                  = Kernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;

        float ave_time = ck_tile::launch_kernel(
            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

        return ave_time;
    }

    return 0;
}

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3328", "m dimension")
        .insert("n", "4096", "m dimension")
        .insert("e", "1e-5", "epsilon")
        .insert("v", "1", "cpu validation or not")
        .insert("prec", "fp16", "precision");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    float epsilon         = arg_parser.get_float("e");
    ck_tile::index_t M    = arg_parser.get_int("m");
    ck_tile::index_t N    = arg_parser.get_int("n");
    std::string data_type = arg_parser.get_str("prec");
    int do_validation     = arg_parser.get_int("v");

    using XDataType     = ck_tile::half_t;
    using YDataType     = ck_tile::half_t;
    using GammaDataType = ck_tile::half_t;
    using BetaDataType  = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
    using MeanDataType   = ck_tile::half_t;
    using InvStdDataType = ck_tile::half_t;
#else
    using MeanDataType   = ck_tile::null_type;
    using InvStdDataType = ck_tile::null_type;
#endif
    using ComputeDataType = float;

    // host verify
    ck_tile::HostTensor<XDataType> x_host({M, N});
    ck_tile::HostTensor<GammaDataType> gamma_host({N});
    ck_tile::HostTensor<BetaDataType> beta_host({N});

    ck_tile::HostTensor<YDataType> y_host_ref({M, N});
    ck_tile::HostTensor<YDataType> y_host_dev({M, N});

    ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
    ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
#ifdef SAVE_MEAN_INV_STD
    ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
    ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
#endif

    ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
    ck_tile::FillUniformDistribution<GammaDataType>{-5.f, 5.f}(gamma_host);
    ck_tile::FillUniformDistribution<BetaDataType>{-5.f, 5.f}(beta_host);

    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
#ifdef SAVE_MEAN_INV_STD
    ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
#endif

    x_buf.ToDevice(x_host.data());
    gamma_buf.ToDevice(gamma_host.data());
    beta_buf.ToDevice(beta_host.data());

    layernorm2d_fwd_traits traits{data_type};

    layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                              gamma_buf.GetDeviceBuffer(),
                              beta_buf.GetDeviceBuffer(),
                              y_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
                              mean_buf.GetDeviceBuffer(),
                              invStd_buf.GetDeviceBuffer(),
#else
                              nullptr,
                              nullptr,
#endif
                              epsilon,
                              M,
                              N};

    float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});

    std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
                           sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;

    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "[" << data_type << "]"
              << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s"
              << std::flush;

    bool pass = true;

    if(do_validation)
    {
        // reference
        ck_tile::reference_layernorm2d_fwd<XDataType,
                                           GammaDataType,
                                           BetaDataType,
                                           ComputeDataType,
                                           YDataType,
                                           MeanDataType,
                                           InvStdDataType>(
            x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);

        y_buf.FromDevice(y_host_dev.data());

        pass = ck_tile::check_err(y_host_dev, y_host_ref);

#ifdef SAVE_MEAN_INV_STD
        mean_buf.FromDevice(mean_host_dev.data());
        pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);

        invStd_buf.FromDevice(invStd_host_dev.data());
        pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
#endif

        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
    }

    std::cout << std::endl << std::flush;

    return !pass;
}
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp  (new file mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>

struct layernorm2d_fwd_traits
{
    std::string data_type;
};

struct layernorm2d_fwd_args
{
    const void* p_x;
    const void* p_gamma;
    const void* p_beta;
    void* p_y;
    void* p_mean;
    void* p_invStd;
    float epsilon;
    ck_tile::index_t M;
    ck_tile::index_t N;
};

// host API
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
example/ck_tile/CMakeLists.txt

@@ -3,3 +3,4 @@ include_directories(AFTER
 )

 add_subdirectory(01_fmha)
+add_subdirectory(02_layernorm2d)
include/ck/ck.hpp

@@ -151,7 +151,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1

 // LDS direct loads using inline assembly
-#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
+#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0

 // set stochastic rounding as default for f8 conversions
 #define CK_USE_SR_F8_CONVERSION 1
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -300,9 +300,9 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

-                    dpp_gemm.template Run(a_thread_vec.template AsType<dpp_input_type>(),
-                                          b_thread_vec.template AsType<dpp_input_type>(),
-                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    dpp_gemm.Run(a_thread_vec.template AsType<dpp_input_type>(),
+                                 b_thread_vec.template AsType<dpp_input_type>(),
+                                 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -613,7 +613,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -681,7 +681,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -749,10 +749,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -808,10 +807,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -840,10 +838,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -901,10 +898,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -939,10 +935,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = 1;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         ignore = num_loop;
         return TailNumber::Full;

@@ -259,7 +259,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -319,10 +319,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -446,12 +445,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
     static constexpr index_t PrefetchStages  = 1;
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = 1;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         ignore = num_loop;
         return TailNumber::Full;

@@ -584,7 +583,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                         block_sync_lds();
                         __builtin_amdgcn_sched_barrier(0);
                     }
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -668,7 +667,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                         block_sync_lds();
                         __builtin_amdgcn_sched_barrier(0);
                     }
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = PrefetchStages;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         if(num_loop % PrefetchStages == 1)
         {

@@ -303,7 +303,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -374,7 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -428,10 +428,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -480,10 +479,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -646,12 +644,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = PrefetchStages;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         if(num_loop % PrefetchStages == 1)
         {

@@ -821,7 +819,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                         block_sync_lds();
                         __builtin_amdgcn_sched_barrier(0);
                     }
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -914,7 +912,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                         block_sync_lds();
                         __builtin_amdgcn_sched_barrier(0);
                     }
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -990,7 +988,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                         block_sync_lds();
                         __builtin_amdgcn_sched_barrier(0);
                     }
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -1066,7 +1064,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                         block_sync_lds();
                         __builtin_amdgcn_sched_barrier(0);
                     }
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
     static constexpr index_t PrefillStages   = 1;
     static constexpr index_t GlobalBufferNum = 1;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         ignore = num_loop;
         return TailNumber::Full;

@@ -381,7 +381,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -440,10 +440,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
     static constexpr index_t GlobalBufferNum = 2;
     static constexpr index_t HotloopUnroll   = 2;

-    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
+    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
     {
         return num_loop > PrefetchStages;
     }

-    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
     {
         if(num_loop % HotloopUnroll == 1)
         {

@@ -403,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -472,10 +472,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -529,10 +528,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -562,10 +560,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -444,7 +444,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.Run(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -513,10 +513,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });

                 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

@@ -564,10 +563,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });

                 a_thread_copy_.Run(

@@ -607,10 +605,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -850,10 +850,9 @@ struct BlockwiseGemmWMMA
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    wmma_gemm.template Run(a_thread_vec.template AsType<wmma_input_type_a>(),
-                                           b_thread_vec.template AsType<wmma_input_type_b>(),
-                                           c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
+                                  b_thread_vec.template AsType<wmma_input_type_b>(),
+                                  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -909,10 +908,9 @@ struct BlockwiseGemmWMMA
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    wmma_gemm.template Run(a_thread_vec.template AsType<wmma_input_type_a>(),
-                                           b_thread_vec.template AsType<wmma_input_type_b>(),
-                                           c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
+                                  b_thread_vec.template AsType<wmma_input_type_b>(),
+                                  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -340,10 +340,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type_a>(),
-                                             b_thread_vec.template AsType<mfma_input_type_b>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
+                                    b_thread_vec.template AsType<mfma_input_type_b>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });

@@ -537,10 +536,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                     // TODO: insert setprio in more precise manner since we
                     // could have more than >1 MFMA instructions in single call
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type_a>(),
-                                             b_thread_vec.template AsType<mfma_input_type_b>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
+                                    b_thread_vec.template AsType<mfma_input_type_b>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

                     if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                     {
                         __builtin_amdgcn_sched_barrier(0);

@@ -970,10 +968,9 @@ struct BlockwiseGemmXdlops_v2
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -281,10 +281,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp  (new file mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp"
#include "ck/utility/is_detected.hpp"

namespace ck {

// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
//   1. All sources and destinations are DynamicBuffer
//   2. Same VectorDim and ScalerPerVector for all sources and destinations
//   3. DstInMemOps are per destination tensor
//   4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
//   5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
//   1. Pass tensor descritpors by reference (or tuple of references)
//   2. Does not keep reference to tensor descriptor
//   3. Does not construct new tensor coordinate when call Run()
template <typename ThreadGroup,
          typename SrcDatas,
          typename DstDatas,
          typename SrcDescs,
          typename DstDescs,
          typename ElementwiseOperation,
          typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          typename SrcScalarPerVectors,
          index_t DstScalarPerVector,
          typename ThreadTransferSrcResetCoordinateAfterRunFlags,
          typename ThreadTransferDstResetCoordinateAfterRunFlags,
          index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r3
{
    static constexpr index_t nDim =
        remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();

    static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
    static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();

    using Index = MultiIndex<nDim>;

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3(
        const SrcDescs& src_descs,
        const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
        const DstDescs& dst_descs,
        const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_descs,
                               StaticallyIndexedArray<Index, nSrc>{},
                               dst_descs,
                               StaticallyIndexedArray<Index, nDst>{},
                               element_op)
    {
        static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                          nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
                          nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
                          nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
                      "wrong!");

        static_for<0, nSrc, 1>{}([&](auto i) {
            static_assert(
                nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
                "wrong!");
        });

        static_for<0, nDst, 1>{}([&](auto i) {
            static_assert(
                nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
                "wrong!");
        });

        static_assert(nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            const auto src_thread_slice_origins = generate_tuple(
                [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
                Number<nSrc>{});

            const auto dst_thread_slice_origins = generate_tuple(
                [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
                Number<nDst>{});

            threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
            threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
        }
    }

    template <typename SrcBuffers, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDescs& src_descs,
                            const SrcBuffers& src_bufs,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
        }
    }

    template <typename T>
    using is_tuple = decltype(std::declval<T&>().IsTuple());

    template <typename DstBuffers, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDescs& dst_descs,
                             DstBuffers dst_bufs,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
                threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
            else
                threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
        }
    }

    template <typename SrcBuffers, typename DstBuffers>
    __device__ void Run(const SrcDescs& src_descs,
                        const SrcBuffers& src_bufs,
                        const DstDescs& dst_descs,
                        DstBuffers dst_bufs)
    {
        RunRead(src_descs, src_bufs);
        RunWrite(dst_descs, dst_bufs);
    }

    template <index_t ISrc>
    __device__ void
    MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
    {
        static_for<0, SrcDescs::Size(), 1>{}(
            [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
    }

    template <index_t IDst>
    __device__ void
    MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
    {
        static_for<0, DstDescs::Size(), 1>{}(
            [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v7r3<SrcDatas,
                                           DstDatas,
                                           SrcDescs,
                                           DstDescs,
                                           ElementwiseOperation,
                                           DstInMemOps,
                                           decltype(thread_slice_lengths),
                                           SrcDimAccessOrder,
                                           DstDimAccessOrder,
                                           SrcVectorDim,
                                           DstVectorDim,
                                           SrcScalarPerVectors,
                                           DstScalarPerVector,
                                           ThreadTransferSrcResetCoordinateAfterRunFlags,
                                           ThreadTransferDstResetCoordinateAfterRunFlags,
                                           NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp  (new file mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemm_Streamk_V2 : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        ck::index_t Streamk_sel,
                        ck::index_t Grid_size,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
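A rough usage sketch for this new interface, following the general CK device-op pattern (the PassThrough element-wise operator, BaseInvoker::Run, and StreamConfig come from the existing library, and the pointer/size variables plus the Streamk_sel and Grid_size values below are placeholder assumptions, not part of this diff):

    // `gemm` is some concrete instance derived from DeviceGemm_Streamk_V2<...>
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto argument = gemm.MakeArgumentPointer(p_a, p_b, p_c,
                                             M, N, K,
                                             StrideA, StrideB, StrideC,
                                             /*Streamk_sel=*/1, /*Grid_size=*/-1, // placeholder choices
                                             PassThrough{}, PassThrough{}, PassThrough{});
    auto invoker  = gemm.MakeInvokerPointer();
    float time_ms = invoker->Run(argument.get(), StreamConfig{nullptr, true});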