Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cbf281f0
Commit
cbf281f0
authored
Aug 22, 2023
by
Bartlomiej Wroblewski
Browse files
Merge remote-tracking branch 'origin/develop' into bwroblew/contrib
parents
f3aceeab
d52ec016
Changes
77
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
794 additions
and
254 deletions
+794
-254
CHANGELOG.md
CHANGELOG.md
+2
-2
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
+37
-22
client_example/19_pool_fwd/max_pool2d_fwd.cpp
client_example/19_pool_fwd/max_pool2d_fwd.cpp
+75
-20
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+5
-18
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+2
-0
example/01_gemm/gemm_dl_dpp8_fp16.cpp
example/01_gemm/gemm_dl_dpp8_fp16.cpp
+37
-0
example/13_pool2d_fwd/pool2d_fwd_common.hpp
example/13_pool2d_fwd/pool2d_fwd_common.hpp
+24
-18
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
+32
-26
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
+32
-26
example/48_pool3d_fwd/pool3d_fwd_common.hpp
example/48_pool3d_fwd/pool3d_fwd_common.hpp
+55
-44
example/48_pool3d_fwd/pool3d_fwd_fp16.cpp
example/48_pool3d_fwd/pool3d_fwd_fp16.cpp
+39
-18
example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp
example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp
+16
-12
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
+21
-15
example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp
example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp
+16
-12
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
+16
-12
include/ck/ck.hpp
include/ck/ck.hpp
+3
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp
.../ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp
+370
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
.../ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
+10
-7
No files found.
CHANGELOG.md
View file @
cbf281f0
...
@@ -25,8 +25,8 @@ Full documentation for Composable Kernel is not yet available.
...
@@ -25,8 +25,8 @@ Full documentation for Composable Kernel is not yet available.
-
Added multi-embeddings support (#542).
-
Added multi-embeddings support (#542).
-
Added Navi3x blockwise GEMM and real GEMM support (#541).
-
Added Navi3x blockwise GEMM and real GEMM support (#541).
-
Added Navi grouped ConvBwdWeight support (#505).
-
Added Navi grouped ConvBwdWeight support (#505).
-
Added
pool3d
forward (#
697
).
-
Added
MaxPool, AvgPool
forward (#
815
).
-
Added
m
ax
p
ool backward (#750).
-
Added
M
ax
P
ool backward (#750).
### Changed
### Changed
-
Changed ...
-
Changed ...
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
View file @
cbf281f0
...
@@ -16,6 +16,9 @@ using InDataType = ck::half_t;
...
@@ -16,6 +16,9 @@ using InDataType = ck::half_t;
using
OutDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
IndexDataType
=
int32_t
;
using
IndexDataType
=
int32_t
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWC
;
constexpr
ck
::
index_t
InOutRank
=
5
;
constexpr
ck
::
index_t
InOutRank
=
5
;
constexpr
ck
::
index_t
WindowRank
=
3
;
constexpr
ck
::
index_t
WindowRank
=
3
;
#if 0
#if 0
...
@@ -44,33 +47,41 @@ struct SimpleDeviceMem
...
@@ -44,33 +47,41 @@ struct SimpleDeviceMem
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
ck
::
index_t
N
=
2
;
ck
::
index_t
N
=
2
;
ck
::
index_t
C
=
32
;
ck
::
index_t
C
=
32
;
ck
::
index_t
Z
=
2
;
ck
::
index_t
Z
=
2
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
Di
=
30
;
ck
::
index_t
Di
=
30
;
ck
::
index_t
Hi
=
30
;
ck
::
index_t
Hi
=
30
;
ck
::
index_t
Wi
=
30
;
ck
::
index_t
Wi
=
30
;
ck
::
index_t
window_stride_d
=
2
;
ck
::
index_t
window_stride_d
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_d
=
1
;
ck
::
index_t
window_dilation_d
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
window_dilation_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
window_dilation_w
=
1
;
ck
::
index_t
in_right_pad_d
=
1
;
ck
::
index_t
in_left_pad_d
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_d
=
1
;
ck
::
index_t
Do
=
(
Di
+
in_left_pad_d
+
in_right_pad_d
-
Z
)
/
window_stride_d
+
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
const
ck
::
index_t
Zs
=
(
Z
-
1
)
*
window_dilation_d
+
1
;
const
ck
::
index_t
Ys
=
(
Y
-
1
)
*
window_dilation_h
+
1
;
const
ck
::
index_t
Xs
=
(
X
-
1
)
*
window_dilation_w
+
1
;
ck
::
index_t
Do
=
(
Di
+
in_left_pad_d
+
in_right_pad_d
-
Zs
)
/
window_stride_d
+
1
;
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Ys
)
/
window_stride_h
+
1
;
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
Xs
)
/
window_stride_w
+
1
;
// Pool API only support the order of NCDHW
// Pool API only support the order of NCDHW
std
::
vector
<
ck
::
index_t
>
in_length
=
{
N
,
C
,
Di
,
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
in_length
=
{
N
,
C
,
Di
,
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
out_length
=
{
N
,
C
,
Do
,
Ho
,
Wo
};
std
::
vector
<
ck
::
index_t
>
out_length
=
{
N
,
C
,
Do
,
Ho
,
Wo
};
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
=
{
Z
,
Y
,
X
};
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
=
{
Z
,
Y
,
X
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
window_stride_d
,
window_stride_h
,
window_stride_w
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
window_stride_d
,
window_stride_h
,
window_stride_w
};
std
::
vector
<
ck
::
index_t
>
window_dilations
{
window_dilation_d
,
window_dilation_h
,
window_dilation_w
};
std
::
vector
<
ck
::
index_t
>
input_left_pads
=
{
in_left_pad_d
,
in_left_pad_h
,
in_left_pad_w
};
std
::
vector
<
ck
::
index_t
>
input_left_pads
=
{
in_left_pad_d
,
in_left_pad_h
,
in_left_pad_w
};
std
::
vector
<
ck
::
index_t
>
input_right_pads
=
{
in_right_pad_d
,
in_right_pad_h
,
in_right_pad_w
};
std
::
vector
<
ck
::
index_t
>
input_right_pads
=
{
in_right_pad_d
,
in_right_pad_h
,
in_right_pad_w
};
...
@@ -90,6 +101,8 @@ int main(int argc, char* argv[])
...
@@ -90,6 +101,8 @@ int main(int argc, char* argv[])
InDataType
,
InDataType
,
OutDataType
,
OutDataType
,
IndexDataType
,
IndexDataType
,
InLayout
,
OutLayout
,
ReduceOpId
,
ReduceOpId
,
OutputIndex
>
;
OutputIndex
>
;
...
@@ -122,6 +135,7 @@ int main(int argc, char* argv[])
...
@@ -122,6 +135,7 @@ int main(int argc, char* argv[])
out_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
{
2
,
3
,
4
});
{
2
,
3
,
4
});
...
@@ -181,6 +195,7 @@ int main(int argc, char* argv[])
...
@@ -181,6 +195,7 @@ int main(int argc, char* argv[])
out_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
{
2
,
3
,
4
});
{
2
,
3
,
4
});
...
...
client_example/19_pool_fwd/max_pool2d_fwd.cpp
View file @
cbf281f0
...
@@ -10,14 +10,18 @@
...
@@ -10,14 +10,18 @@
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/pool
2
d_fwd.hpp"
#include "ck/library/tensor_operation_instance/gpu/pool
3
d_fwd.hpp"
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
IndexDataType
=
int32_t
;
using
IndexDataType
=
int32_t
;
constexpr
ck
::
index_t
InOutRank
=
4
;
// We use pool3d to implement pool2d in this example
constexpr
ck
::
index_t
WindowRank
=
2
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWC
;
constexpr
ck
::
index_t
InOutRank
=
5
;
constexpr
ck
::
index_t
WindowRank
=
3
;
#if 1
#if 1
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
MAX
;
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
MAX
;
constexpr
bool
OutputIndex
=
true
;
constexpr
bool
OutputIndex
=
true
;
...
@@ -42,31 +46,66 @@ struct SimpleDeviceMem
...
@@ -42,31 +46,66 @@ struct SimpleDeviceMem
void
*
p_mem_
;
void
*
p_mem_
;
};
};
void
TransformPool2dparamToPool3d
(
std
::
vector
<
ck
::
index_t
>&
input_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_lengths
,
std
::
vector
<
ck
::
index_t
>&
output_lengths
,
std
::
vector
<
ck
::
index_t
>&
input_stride
,
std
::
vector
<
ck
::
index_t
>&
output_stride
,
std
::
vector
<
ck
::
index_t
>&
indices_stride
,
std
::
vector
<
ck
::
index_t
>&
window_strides
,
std
::
vector
<
ck
::
index_t
>&
window_dilations
,
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
std
::
vector
<
ck
::
index_t
>&
input_right_pads
,
std
::
vector
<
ck
::
index_t
>&
pooling_dims
)
{
// NCHW to NCDHW
input_lengths
.
insert
(
input_lengths
.
begin
()
+
2
,
1
);
output_lengths
.
insert
(
output_lengths
.
begin
()
+
2
,
1
);
input_stride
.
insert
(
input_stride
.
begin
()
+
2
,
0
);
output_stride
.
insert
(
output_stride
.
begin
()
+
2
,
0
);
indices_stride
.
insert
(
indices_stride
.
begin
()
+
2
,
0
);
// YX to ZYX
window_lengths
.
insert
(
window_lengths
.
begin
(),
1
);
window_strides
.
insert
(
window_strides
.
begin
(),
0
);
window_dilations
.
insert
(
window_dilations
.
begin
(),
0
);
input_left_pads
.
insert
(
input_left_pads
.
begin
(),
0
);
input_right_pads
.
insert
(
input_right_pads
.
begin
(),
0
);
pooling_dims
=
{
2
,
3
,
4
};
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
ck
::
index_t
N
=
2
;
ck
::
index_t
N
=
2
;
ck
::
index_t
C
=
32
;
ck
::
index_t
C
=
32
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
Hi
=
30
;
ck
::
index_t
Hi
=
30
;
ck
::
index_t
Wi
=
30
;
ck
::
index_t
Wi
=
30
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
window_dilation_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
window_dilation_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
const
ck
::
index_t
Ys
=
(
Y
-
1
)
*
window_dilation_h
+
1
;
const
ck
::
index_t
Xs
=
(
X
-
1
)
*
window_dilation_w
+
1
;
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Ys
)
/
window_stride_h
+
1
;
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
Xs
)
/
window_stride_w
+
1
;
// Pool API only support the order of NCHW
// Pool API only support the order of NCHW
std
::
vector
<
ck
::
index_t
>
in_length
=
{
N
,
C
,
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
in_length
=
{
N
,
C
,
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
out_length
=
{
N
,
C
,
Ho
,
Wo
};
std
::
vector
<
ck
::
index_t
>
out_length
=
{
N
,
C
,
Ho
,
Wo
};
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
=
{
Y
,
X
};
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
=
{
Y
,
X
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
window_stride_h
,
window_stride_w
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
window_stride_h
,
window_stride_w
};
std
::
vector
<
ck
::
index_t
>
window_dilations
=
{
window_dilation_h
,
window_dilation_w
};
std
::
vector
<
ck
::
index_t
>
input_left_pads
=
{
in_left_pad_h
,
in_left_pad_w
};
std
::
vector
<
ck
::
index_t
>
input_left_pads
=
{
in_left_pad_h
,
in_left_pad_w
};
std
::
vector
<
ck
::
index_t
>
input_right_pads
=
{
in_right_pad_h
,
in_right_pad_w
};
std
::
vector
<
ck
::
index_t
>
input_right_pads
=
{
in_right_pad_h
,
in_right_pad_w
};
std
::
vector
<
ck
::
index_t
>
pooling_dims
=
{
2
,
3
};
std
::
size_t
in_tensor_size
=
N
*
C
*
Hi
*
Wi
;
std
::
size_t
in_tensor_size
=
N
*
C
*
Hi
*
Wi
;
std
::
size_t
out_tensor_size
=
N
*
C
*
Ho
*
Wo
;
std
::
size_t
out_tensor_size
=
N
*
C
*
Ho
*
Wo
;
...
@@ -75,6 +114,18 @@ int main(int argc, char* argv[])
...
@@ -75,6 +114,18 @@ int main(int argc, char* argv[])
std
::
vector
<
ck
::
index_t
>
in_tensor_stride
=
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
};
std
::
vector
<
ck
::
index_t
>
in_tensor_stride
=
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
};
std
::
vector
<
ck
::
index_t
>
out_tensor_stride
=
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
};
std
::
vector
<
ck
::
index_t
>
out_tensor_stride
=
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
};
TransformPool2dparamToPool3d
(
in_length
,
window_spatial_lengths
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
,
pooling_dims
);
SimpleDeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_tensor_size
);
SimpleDeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_tensor_size
);
SimpleDeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_tensor_size
);
SimpleDeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_tensor_size
);
SimpleDeviceMem
out_indices_device_buf
(
sizeof
(
IndexDataType
)
*
out_tensor_size
);
SimpleDeviceMem
out_indices_device_buf
(
sizeof
(
IndexDataType
)
*
out_tensor_size
);
...
@@ -84,6 +135,8 @@ int main(int argc, char* argv[])
...
@@ -84,6 +135,8 @@ int main(int argc, char* argv[])
InDataType
,
InDataType
,
OutDataType
,
OutDataType
,
IndexDataType
,
IndexDataType
,
InLayout
,
OutLayout
,
ReduceOpId
,
ReduceOpId
,
OutputIndex
>
;
OutputIndex
>
;
...
@@ -116,9 +169,10 @@ int main(int argc, char* argv[])
...
@@ -116,9 +169,10 @@ int main(int argc, char* argv[])
out_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
{
2
,
3
}
);
pooling_dims
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
@@ -175,9 +229,10 @@ int main(int argc, char* argv[])
...
@@ -175,9 +229,10 @@ int main(int argc, char* argv[])
out_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
{
2
,
3
}
);
pooling_dims
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
...
docs/sphinx/requirements.in
View file @
cbf281f0
rocm-docs-core
=
=0.
1
0.
3
rocm-docs-core
>
=0.
2
0.
0
sphinxcontrib-bibtex==2.5.0
sphinxcontrib-bibtex==2.5.0
docs/sphinx/requirements.txt
View file @
cbf281f0
...
@@ -38,6 +38,8 @@ docutils==0.16
...
@@ -38,6 +38,8 @@ docutils==0.16
# pydata-sphinx-theme
# pydata-sphinx-theme
# sphinx
# sphinx
# sphinxcontrib-bibtex
# sphinxcontrib-bibtex
fastjsonschema==2.18.0
# via rocm-docs-core
gitdb==4.0.10
gitdb==4.0.10
# via gitpython
# via gitpython
gitpython==3.1.31
gitpython==3.1.31
...
@@ -46,20 +48,12 @@ idna==3.4
...
@@ -46,20 +48,12 @@ idna==3.4
# via requests
# via requests
imagesize==1.4.1
imagesize==1.4.1
# via sphinx
# via sphinx
importlib-metadata==6.0.0
# via
# sphinx
# sphinxcontrib-bibtex
importlib-resources==5.12.0
# via rocm-docs-core
jinja2==3.1.2
jinja2==3.1.2
# via
# via
# myst-parser
# myst-parser
# sphinx
# sphinx
latexcodec==2.0.1
latexcodec==2.0.1
# via pybtex
# via pybtex
linkify-it-py==1.0.3
# via myst-parser
markdown-it-py==2.2.0
markdown-it-py==2.2.0
# via
# via
# mdit-py-plugins
# mdit-py-plugins
...
@@ -70,7 +64,7 @@ mdit-py-plugins==0.3.5
...
@@ -70,7 +64,7 @@ mdit-py-plugins==0.3.5
# via myst-parser
# via myst-parser
mdurl==0.1.2
mdurl==0.1.2
# via markdown-it-py
# via markdown-it-py
myst-parser
[linkify]
==1.0.0
myst-parser==1.0.0
# via rocm-docs-core
# via rocm-docs-core
packaging==23.0
packaging==23.0
# via
# via
...
@@ -99,18 +93,17 @@ pyjwt[crypto]==2.6.0
...
@@ -99,18 +93,17 @@ pyjwt[crypto]==2.6.0
# via pygithub
# via pygithub
pynacl==1.5.0
pynacl==1.5.0
# via pygithub
# via pygithub
pytz==2023.3
# via babel
pyyaml==6.0
pyyaml==6.0
# via
# via
# myst-parser
# myst-parser
# pybtex
# pybtex
# rocm-docs-core
# sphinx-external-toc
# sphinx-external-toc
requests==2.28.2
requests==2.28.2
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core
=
=0.
1
0.
3
rocm-docs-core
>
=0.
2
0.
0
# via -r requirements.in
# via -r requirements.in
six==1.16.0
six==1.16.0
# via
# via
...
@@ -160,13 +153,7 @@ sphinxcontrib-serializinghtml==1.1.5
...
@@ -160,13 +153,7 @@ sphinxcontrib-serializinghtml==1.1.5
# via sphinx
# via sphinx
typing-extensions==4.5.0
typing-extensions==4.5.0
# via pydata-sphinx-theme
# via pydata-sphinx-theme
uc-micro-py==1.0.1
# via linkify-it-py
urllib3==1.26.15
urllib3==1.26.15
# via requests
# via requests
wrapt==1.15.0
wrapt==1.15.0
# via deprecated
# via deprecated
zipp==3.15.0
# via
# importlib-metadata
# importlib-resources
example/01_gemm/CMakeLists.txt
View file @
cbf281f0
...
@@ -6,6 +6,8 @@ if(DL_KERNELS)
...
@@ -6,6 +6,8 @@ if(DL_KERNELS)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_dl_fp16 gemm_dl_fp16.cpp
)
add_example_executable
(
example_gemm_dl_fp16 gemm_dl_fp16.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
add_dependencies
(
example_gemm_dl example_gemm_dl_fp16
)
add_example_executable
(
example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp
)
add_dependencies
(
example_gemm_dl example_gemm_dl_dpp8_fp16
)
endif
()
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
add_example_executable
(
example_gemm_dl_int8 gemm_dl_int8.cpp
)
...
...
example/01_gemm/gemm_dl_dpp8_fp16.cpp
0 → 100644
View file @
cbf281f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp"
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ALayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmDlDpp8
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
128
,
128
,
16
,
2
,
1
,
8
,
8
,
S
<
8
,
8
>
,
S
<
4
,
1
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
#include "run_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/13_pool2d_fwd/pool2d_fwd_common.hpp
View file @
cbf281f0
...
@@ -39,31 +39,35 @@ bool pool_test(bool do_verification,
...
@@ -39,31 +39,35 @@ bool pool_test(bool do_verification,
ck
::
index_t
Wi
,
ck
::
index_t
Wi
,
ck
::
index_t
window_stride_h
,
ck
::
index_t
window_stride_h
,
ck
::
index_t
window_stride_w
,
ck
::
index_t
window_stride_w
,
ck
::
index_t
window_dilation_h
,
ck
::
index_t
window_dilation_w
,
ck
::
index_t
in_left_pad_h
,
ck
::
index_t
in_left_pad_h
,
ck
::
index_t
in_left_pad_w
,
ck
::
index_t
in_left_pad_w
,
ck
::
index_t
in_right_pad_h
,
ck
::
index_t
in_right_pad_h
,
ck
::
index_t
in_right_pad_w
)
ck
::
index_t
in_right_pad_w
)
{
{
using
DevicePoolFwdInstance
=
using
DevicePoolFwdInstance
=
ck
::
tensor_operation
::
device
::
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
ck
::
tensor_operation
::
device
::
DevicePool2dFwd_NHWC_NHWC
<
InDataType
,
InDataType
,
// InDataType
OutDataType
,
OutDataType
,
// OutDataType
IndexDataType
,
IndexDataType
,
// IndexDataType
ComputeDataType
,
ComputeDataType
,
// ComputeDataType
ReduceOpId
,
ReduceOpId
,
OutputIndex
,
OutputIndex
,
64
,
// BlockSize
64
,
// BlockSize
64
,
// ReduceMThreadClusterSize
64
,
// ReduceMThreadClusterSize
1
,
// ReduceKThreadClusterSize
1
,
// ReduceKThreadClusterSize
4
,
// ReduceMThreadSliceSize
4
,
// ReduceMThreadSliceSize
1
,
// ReduceKThreadSliceSize
1
,
// ReduceKThreadSliceSize
1
>
;
// InSrcOutDstVectorSize
4
>
;
// InSrcOutDstVectorSize
const
ck
::
index_t
Ys
=
(
Y
-
1
)
*
window_dilation_h
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
const
ck
::
index_t
Xs
=
(
X
-
1
)
*
window_dilation_w
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Ys
)
/
window_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
Xs
)
/
window_stride_w
+
1
;
const
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
{
Y
,
X
};
const
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
{
Y
,
X
};
const
std
::
vector
<
ck
::
index_t
>
window_strides
{
window_stride_h
,
window_stride_w
};
const
std
::
vector
<
ck
::
index_t
>
window_strides
{
window_stride_h
,
window_stride_w
};
const
std
::
vector
<
ck
::
index_t
>
window_dilations
{
window_dilation_h
,
window_dilation_w
};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_h
,
in_right_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_h
,
in_right_pad_w
};
...
@@ -123,6 +127,7 @@ bool pool_test(bool do_verification,
...
@@ -123,6 +127,7 @@ bool pool_test(bool do_verification,
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
{
2
,
3
});
{
2
,
3
});
...
@@ -144,8 +149,8 @@ bool pool_test(bool do_verification,
...
@@ -144,8 +149,8 @@ bool pool_test(bool do_verification,
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
std
::
endl
;
<<
" GB / s "
<<
std
::
endl
;
bool
pass
=
true
;
bool
pass
=
true
;
...
@@ -169,6 +174,7 @@ bool pool_test(bool do_verification,
...
@@ -169,6 +174,7 @@ bool pool_test(bool do_verification,
out_indices_n_c_ho_wo_host
,
out_indices_n_c_ho_wo_host
,
window_spatial_lengths
,
window_spatial_lengths
,
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
);
input_right_pads
);
...
...
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
View file @
cbf281f0
...
@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
...
@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
bool
time_kernel
;
bool
time_kernel
;
// Pool shape
// Pool shape
ck
::
index_t
N
=
128
;
ck
::
index_t
N
=
128
;
ck
::
index_t
C
=
192
;
ck
::
index_t
C
=
192
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Wi
=
71
;
ck
::
index_t
Wi
=
71
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
window_dilation_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
window_dilation_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
...
@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
...
@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
}
}
else
if
(
argc
==
1
6
)
else
if
(
argc
==
1
8
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
N
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
4
]);
C
=
std
::
stoi
(
argv
[
5
]);
C
=
std
::
stoi
(
argv
[
5
]);
Y
=
std
::
stoi
(
argv
[
6
]);
Y
=
std
::
stoi
(
argv
[
6
]);
X
=
std
::
stoi
(
argv
[
7
]);
X
=
std
::
stoi
(
argv
[
7
]);
Hi
=
std
::
stoi
(
argv
[
8
]);
Hi
=
std
::
stoi
(
argv
[
8
]);
Wi
=
std
::
stoi
(
argv
[
9
]);
Wi
=
std
::
stoi
(
argv
[
9
]);
window_stride_h
=
std
::
stoi
(
argv
[
10
]);
window_stride_h
=
std
::
stoi
(
argv
[
10
]);
window_stride_w
=
std
::
stoi
(
argv
[
11
]);
window_stride_w
=
std
::
stoi
(
argv
[
11
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
12
]);
window_dilation_h
=
std
::
stoi
(
argv
[
12
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
13
]);
window_dilation_w
=
std
::
stoi
(
argv
[
13
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
14
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
14
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
15
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
15
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
16
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
17
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
printf
(
"arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx,
Dy, Dx,
LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
"RightPx
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
...
@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
Wi
,
Wi
,
window_stride_h
,
window_stride_h
,
window_stride_w
,
window_stride_w
,
window_dilation_h
,
window_dilation_w
,
in_left_pad_h
,
in_left_pad_h
,
in_left_pad_w
,
in_left_pad_w
,
in_right_pad_h
,
in_right_pad_h
,
...
...
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
View file @
cbf281f0
...
@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
...
@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
bool
time_kernel
;
bool
time_kernel
;
// Pool shape
// Pool shape
ck
::
index_t
N
=
128
;
ck
::
index_t
N
=
128
;
ck
::
index_t
C
=
192
;
ck
::
index_t
C
=
192
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Wi
=
71
;
ck
::
index_t
Wi
=
71
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
window_dilation_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
window_dilation_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
...
@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
...
@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
}
}
else
if
(
argc
==
1
6
)
else
if
(
argc
==
1
8
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
N
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
4
]);
C
=
std
::
stoi
(
argv
[
5
]);
C
=
std
::
stoi
(
argv
[
5
]);
Y
=
std
::
stoi
(
argv
[
6
]);
Y
=
std
::
stoi
(
argv
[
6
]);
X
=
std
::
stoi
(
argv
[
7
]);
X
=
std
::
stoi
(
argv
[
7
]);
Hi
=
std
::
stoi
(
argv
[
8
]);
Hi
=
std
::
stoi
(
argv
[
8
]);
Wi
=
std
::
stoi
(
argv
[
9
]);
Wi
=
std
::
stoi
(
argv
[
9
]);
window_stride_h
=
std
::
stoi
(
argv
[
10
]);
window_stride_h
=
std
::
stoi
(
argv
[
10
]);
window_stride_w
=
std
::
stoi
(
argv
[
11
]);
window_stride_w
=
std
::
stoi
(
argv
[
11
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
12
]);
window_dilation_h
=
std
::
stoi
(
argv
[
12
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
13
]);
window_dilation_w
=
std
::
stoi
(
argv
[
13
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
14
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
14
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
15
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
15
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
16
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
17
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
printf
(
"arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx,
Dy, Dx,
LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
"RightPx
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
...
@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
Wi
,
Wi
,
window_stride_h
,
window_stride_h
,
window_stride_w
,
window_stride_w
,
window_dilation_h
,
window_dilation_w
,
in_left_pad_h
,
in_left_pad_h
,
in_left_pad_w
,
in_left_pad_w
,
in_right_pad_h
,
in_right_pad_h
,
...
...
example/48_pool3d_fwd/pool3d_fwd_common.hpp
View file @
cbf281f0
...
@@ -18,7 +18,45 @@
...
@@ -18,7 +18,45 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
template
<
typename
InDataType
,
template
<
typename
TensorLayout
>
std
::
vector
<
ck
::
index_t
>
f_tensor_strides_ncdhw
(
ck
::
index_t
N_
,
ck
::
index_t
C_
,
ck
::
index_t
D
,
ck
::
index_t
H
,
ck
::
index_t
W
,
TensorLayout
layout
)
{
using
namespace
ck
::
literals
;
(
void
)
N_
;
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NCDHW
>::
value
)
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
};
template
<
typename
TensorLayout
>
HostTensorDescriptor
f_host_tensor_descriptor
(
std
::
size_t
N_
,
std
::
size_t
C_
,
std
::
size_t
D
,
std
::
size_t
H
,
std
::
size_t
W
,
TensorLayout
layout
)
{
using
namespace
ck
::
literals
;
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NCDHW
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
});
}
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
};
template
<
typename
DevicePoolFwdInstance
,
typename
InDataType
,
typename
OutDataType
,
typename
OutDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
IndexDataType
,
typename
IndexDataType
,
...
@@ -40,6 +78,9 @@ bool pool3d_test(bool do_verification,
...
@@ -40,6 +78,9 @@ bool pool3d_test(bool do_verification,
ck
::
index_t
window_stride_d
,
ck
::
index_t
window_stride_d
,
ck
::
index_t
window_stride_h
,
ck
::
index_t
window_stride_h
,
ck
::
index_t
window_stride_w
,
ck
::
index_t
window_stride_w
,
ck
::
index_t
window_dilation_d
,
ck
::
index_t
window_dilation_h
,
ck
::
index_t
window_dilation_w
,
ck
::
index_t
in_left_pad_d
,
ck
::
index_t
in_left_pad_d
,
ck
::
index_t
in_left_pad_h
,
ck
::
index_t
in_left_pad_h
,
ck
::
index_t
in_left_pad_w
,
ck
::
index_t
in_left_pad_w
,
...
@@ -47,53 +88,21 @@ bool pool3d_test(bool do_verification,
...
@@ -47,53 +88,21 @@ bool pool3d_test(bool do_verification,
ck
::
index_t
in_right_pad_h
,
ck
::
index_t
in_right_pad_h
,
ck
::
index_t
in_right_pad_w
)
ck
::
index_t
in_right_pad_w
)
{
{
using
DevicePoolFwdInstance
=
const
ck
::
index_t
Zs
=
(
Z
-
1
)
*
window_dilation_d
+
1
;
ck
::
tensor_operation
::
device
::
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
<
const
ck
::
index_t
Ys
=
(
Y
-
1
)
*
window_dilation_h
+
1
;
InDataType
,
// InDataType
const
ck
::
index_t
Xs
=
(
X
-
1
)
*
window_dilation_w
+
1
;
OutDataType
,
// OutDataType
const
ck
::
index_t
Do
=
(
Di
+
in_left_pad_d
+
in_right_pad_d
-
Zs
)
/
window_stride_d
+
1
;
IndexDataType
,
// IndexDataType
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Ys
)
/
window_stride_h
+
1
;
ComputeDataType
,
// ComputeDataType
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
Xs
)
/
window_stride_w
+
1
;
ReduceOpId
,
OutputIndex
,
64
,
// BlockSize
64
,
// ReduceMThreadClusterSize
1
,
// ReduceKThreadClusterSize
4
,
// ReduceMThreadSliceSize
1
,
// ReduceKThreadSliceSize
4
>
;
// InSrcOutDstVectorSize
const
ck
::
index_t
Do
=
(
Di
+
in_left_pad_d
+
in_right_pad_d
-
Z
)
/
window_stride_d
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
const
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
{
Z
,
Y
,
X
};
const
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
{
Z
,
Y
,
X
};
const
std
::
vector
<
ck
::
index_t
>
window_strides
{
const
std
::
vector
<
ck
::
index_t
>
window_strides
{
window_stride_d
,
window_stride_h
,
window_stride_w
};
window_stride_d
,
window_stride_h
,
window_stride_w
};
const
std
::
vector
<
ck
::
index_t
>
window_dilations
{
window_dilation_d
,
window_dilation_h
,
window_dilation_w
};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_d
,
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_d
,
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_d
,
in_right_pad_h
,
in_right_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_d
,
in_right_pad_h
,
in_right_pad_w
};
// tensor layout
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
N_
,
std
::
size_t
C_
,
std
::
size_t
D
,
std
::
size_t
H
,
std
::
size_t
W
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NCDHW
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
});
}
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
};
Tensor
<
InDataType
>
in_n_c_di_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Di
,
Hi
,
Wi
,
InLayout
{}));
Tensor
<
InDataType
>
in_n_c_di_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Di
,
Hi
,
Wi
,
InLayout
{}));
Tensor
<
OutDataType
>
out_n_c_do_ho_wo_host
(
Tensor
<
OutDataType
>
out_n_c_do_ho_wo_host
(
f_host_tensor_descriptor
(
N
,
C
,
Do
,
Ho
,
Wo
,
OutLayout
{}));
f_host_tensor_descriptor
(
N
,
C
,
Do
,
Ho
,
Wo
,
OutLayout
{}));
...
@@ -126,10 +135,11 @@ bool pool3d_test(bool do_verification,
...
@@ -126,10 +135,11 @@ bool pool3d_test(bool do_verification,
{
N
,
C
,
Di
,
Hi
,
Wi
},
{
N
,
C
,
Di
,
Hi
,
Wi
},
{
Z
,
Y
,
X
},
{
Z
,
Y
,
X
},
{
N
,
C
,
Do
,
Ho
,
Wo
},
{
N
,
C
,
Do
,
Ho
,
Wo
},
{
Di
*
C
*
Hi
*
Wi
,
1
,
C
*
Hi
*
Wi
,
Wi
*
C
,
C
}
,
f_tensor_strides_ncdhw
(
N
,
C
,
Di
,
Hi
,
Wi
,
InLayout
{})
,
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
}
,
f_tensor_strides_ncdhw
(
N
,
C
,
Do
,
Ho
,
Wo
,
OutLayout
{})
,
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
}
,
f_tensor_strides_ncdhw
(
N
,
C
,
Do
,
Ho
,
Wo
,
OutLayout
{})
,
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
{
2
,
3
,
4
});
{
2
,
3
,
4
});
...
@@ -165,6 +175,7 @@ bool pool3d_test(bool do_verification,
...
@@ -165,6 +175,7 @@ bool pool3d_test(bool do_verification,
out_indices_n_c_do_ho_wo_host
,
out_indices_n_c_do_ho_wo_host
,
window_spatial_lengths
,
window_spatial_lengths
,
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
);
input_right_pads
);
...
...
example/48_pool3d_fwd/pool3d_fwd_fp16.cpp
View file @
cbf281f0
...
@@ -27,31 +27,49 @@ static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
...
@@ -27,31 +27,49 @@ static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
static
constexpr
bool
OutputIndex
=
false
;
static
constexpr
bool
OutputIndex
=
false
;
static
constexpr
bool
PropagateNan
=
false
;
static
constexpr
bool
PropagateNan
=
false
;
using
DevicePoolFwdInstance
=
ck
::
tensor_operation
::
device
::
DevicePool3dFwd_NDHWC_NDHWC
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
64
,
// BlockSize
64
,
// ReduceMThreadClusterSize
1
,
// ReduceKThreadClusterSize
1
,
// ReduceMThreadSliceSize
1
,
// ReduceKThreadSliceSize
1
>
;
// InSrcOutDstVectorSize
int
main
()
int
main
()
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
// Pool shape
// Pool shape
ck
::
index_t
N
=
2
;
ck
::
index_t
N
=
2
;
ck
::
index_t
C
=
32
;
ck
::
index_t
C
=
32
;
ck
::
index_t
Z
=
2
;
ck
::
index_t
Z
=
2
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
Di
=
30
;
ck
::
index_t
Di
=
30
;
ck
::
index_t
Hi
=
30
;
ck
::
index_t
Hi
=
30
;
ck
::
index_t
Wi
=
30
;
ck
::
index_t
Wi
=
30
;
ck
::
index_t
window_stride_d
=
2
;
ck
::
index_t
window_stride_d
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_d
=
1
;
ck
::
index_t
window_dilation_d
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
window_dilation_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
window_dilation_w
=
1
;
ck
::
index_t
in_right_pad_d
=
1
;
ck
::
index_t
in_left_pad_d
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_d
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
bool
pass
=
pool3d_test
<
InDataType
,
bool
pass
=
pool3d_test
<
DevicePoolFwdInstance
,
InDataType
,
OutDataType
,
OutDataType
,
ComputeDataType
,
ComputeDataType
,
IndexDataType
,
IndexDataType
,
...
@@ -72,6 +90,9 @@ int main()
...
@@ -72,6 +90,9 @@ int main()
window_stride_d
,
window_stride_d
,
window_stride_h
,
window_stride_h
,
window_stride_w
,
window_stride_w
,
window_dilation_d
,
window_dilation_h
,
window_dilation_w
,
in_left_pad_d
,
in_left_pad_d
,
in_left_pad_h
,
in_left_pad_h
,
in_left_pad_w
,
in_left_pad_w
,
...
...
example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp
View file @
cbf281f0
...
@@ -24,18 +24,20 @@ int main()
...
@@ -24,18 +24,20 @@ int main()
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
// Pool shape
// Pool shape
ck
::
index_t
N
=
1
;
ck
::
index_t
N
=
1
;
ck
::
index_t
C
=
1
;
ck
::
index_t
C
=
1
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
32
;
ck
::
index_t
Hi
=
32
;
ck
::
index_t
Wi
=
32
;
ck
::
index_t
Wi
=
32
;
ck
::
index_t
window_stride_h
=
1
;
ck
::
index_t
window_stride_h
=
1
;
ck
::
index_t
window_stride_w
=
1
;
ck
::
index_t
window_stride_w
=
1
;
ck
::
index_t
in_left_pad_h
=
0
;
ck
::
index_t
window_dilation_h
=
1
;
ck
::
index_t
in_left_pad_w
=
0
;
ck
::
index_t
window_dilation_w
=
1
;
ck
::
index_t
in_right_pad_h
=
0
;
ck
::
index_t
in_left_pad_h
=
0
;
ck
::
index_t
in_right_pad_w
=
0
;
ck
::
index_t
in_left_pad_w
=
0
;
ck
::
index_t
in_right_pad_h
=
0
;
ck
::
index_t
in_right_pad_w
=
0
;
bool
pass
=
maxpool_bwd_test
<
InDataType
,
bool
pass
=
maxpool_bwd_test
<
InDataType
,
OutDataType
,
OutDataType
,
...
@@ -53,6 +55,8 @@ int main()
...
@@ -53,6 +55,8 @@ int main()
Wi
,
Wi
,
window_stride_h
,
window_stride_h
,
window_stride_w
,
window_stride_w
,
window_dilation_h
,
window_dilation_w
,
in_left_pad_h
,
in_left_pad_h
,
in_left_pad_w
,
in_left_pad_w
,
in_right_pad_h
,
in_right_pad_h
,
...
...
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
View file @
cbf281f0
...
@@ -36,6 +36,8 @@ bool maxpool_bwd_test(bool do_verification,
...
@@ -36,6 +36,8 @@ bool maxpool_bwd_test(bool do_verification,
ck
::
index_t
Wi
,
ck
::
index_t
Wi
,
ck
::
index_t
window_stride_h
,
ck
::
index_t
window_stride_h
,
ck
::
index_t
window_stride_w
,
ck
::
index_t
window_stride_w
,
ck
::
index_t
window_dilation_h
,
ck
::
index_t
window_dilation_w
,
ck
::
index_t
in_left_pad_h
,
ck
::
index_t
in_left_pad_h
,
ck
::
index_t
in_left_pad_w
,
ck
::
index_t
in_left_pad_w
,
ck
::
index_t
in_right_pad_h
,
ck
::
index_t
in_right_pad_h
,
...
@@ -44,28 +46,30 @@ bool maxpool_bwd_test(bool do_verification,
...
@@ -44,28 +46,30 @@ bool maxpool_bwd_test(bool do_verification,
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DevicePoolFwdInstance
=
using
DevicePoolFwdInstance
=
ck
::
tensor_operation
::
device
::
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
ck
::
tensor_operation
::
device
::
DevicePool2dFwd_NHWC_NHWC
<
InDataType
,
// InDataType
InDataType
,
// InDataType
OutDataType
,
// OutDataType
OutDataType
,
// OutDataType
IndexDataType
,
// IndexDataType
IndexDataType
,
// IndexDataType
ComputeDataType
,
// ComputeDataType
ComputeDataType
,
// ComputeDataType
ck
::
ReduceTensorOp
::
MAX
,
ck
::
ReduceTensorOp
::
MAX
,
true
,
true
,
// OutputIndex
64
,
// BlockSize
64
,
// BlockSize
64
,
// ReduceMThreadClusterSize
64
,
// ReduceMThreadClusterSize
1
,
// ReduceKThreadClusterSize
1
,
// ReduceKThreadClusterSize
4
,
// ReduceMThreadSliceSize
4
,
// ReduceMThreadSliceSize
1
,
// ReduceKThreadSliceSize
1
,
// ReduceKThreadSliceSize
1
>
;
// InSrcOutDstVectorSize
1
>
;
// InSrcOutDstVectorSize
using
DeviceMaxPoolBwdInstance
=
ck
::
tensor_operation
::
device
::
using
DeviceMaxPoolBwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceIndexPoolBwdImpl
<
DOutDataType
,
IndexDataType
,
DInDataType
,
4
>
;
DeviceIndexPoolBwdImpl
<
DOutDataType
,
IndexDataType
,
DInDataType
,
4
>
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
const
ck
::
index_t
Ys
=
(
Y
-
1
)
*
window_dilation_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
const
ck
::
index_t
Xs
=
(
X
-
1
)
*
window_dilation_w
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Ys
)
/
window_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
Xs
)
/
window_stride_w
+
1
;
const
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
{
Y
,
X
};
const
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
{
Y
,
X
};
const
std
::
vector
<
ck
::
index_t
>
window_strides
{
window_stride_h
,
window_stride_w
};
const
std
::
vector
<
ck
::
index_t
>
window_strides
{
window_stride_h
,
window_stride_w
};
const
std
::
vector
<
ck
::
index_t
>
window_dilations
{
window_dilation_h
,
window_dilation_w
};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_h
,
in_right_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_h
,
in_right_pad_w
};
...
@@ -128,6 +132,7 @@ bool maxpool_bwd_test(bool do_verification,
...
@@ -128,6 +132,7 @@ bool maxpool_bwd_test(bool do_verification,
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
{
2
,
3
});
{
2
,
3
});
...
@@ -191,6 +196,7 @@ bool maxpool_bwd_test(bool do_verification,
...
@@ -191,6 +196,7 @@ bool maxpool_bwd_test(bool do_verification,
indices_n_c_ho_wo_host
,
indices_n_c_ho_wo_host
,
window_spatial_lengths
,
window_spatial_lengths
,
window_strides
,
window_strides
,
window_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
);
input_right_pads
);
ref_pooling_fwd_invoker
.
Run
(
ref_pooling_fwd_argument
);
ref_pooling_fwd_invoker
.
Run
(
ref_pooling_fwd_argument
);
...
...
example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp
View file @
cbf281f0
...
@@ -24,18 +24,20 @@ int main()
...
@@ -24,18 +24,20 @@ int main()
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
// Pool shape
// Pool shape
ck
::
index_t
N
=
1
;
ck
::
index_t
N
=
1
;
ck
::
index_t
C
=
1
;
ck
::
index_t
C
=
1
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
32
;
ck
::
index_t
Hi
=
32
;
ck
::
index_t
Wi
=
32
;
ck
::
index_t
Wi
=
32
;
ck
::
index_t
window_stride_h
=
1
;
ck
::
index_t
window_stride_h
=
1
;
ck
::
index_t
window_stride_w
=
1
;
ck
::
index_t
window_stride_w
=
1
;
ck
::
index_t
in_left_pad_h
=
0
;
ck
::
index_t
window_dilation_h
=
1
;
ck
::
index_t
in_left_pad_w
=
0
;
ck
::
index_t
window_dilation_w
=
1
;
ck
::
index_t
in_right_pad_h
=
0
;
ck
::
index_t
in_left_pad_h
=
0
;
ck
::
index_t
in_right_pad_w
=
0
;
ck
::
index_t
in_left_pad_w
=
0
;
ck
::
index_t
in_right_pad_h
=
0
;
ck
::
index_t
in_right_pad_w
=
0
;
bool
pass
=
maxpool_bwd_test
<
InDataType
,
bool
pass
=
maxpool_bwd_test
<
InDataType
,
OutDataType
,
OutDataType
,
...
@@ -53,6 +55,8 @@ int main()
...
@@ -53,6 +55,8 @@ int main()
Wi
,
Wi
,
window_stride_h
,
window_stride_h
,
window_stride_w
,
window_stride_w
,
window_dilation_h
,
window_dilation_w
,
in_left_pad_h
,
in_left_pad_h
,
in_left_pad_w
,
in_left_pad_w
,
in_right_pad_h
,
in_right_pad_h
,
...
...
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
View file @
cbf281f0
...
@@ -24,18 +24,20 @@ int main()
...
@@ -24,18 +24,20 @@ int main()
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
// Pool shape
// Pool shape
ck
::
index_t
N
=
1
;
ck
::
index_t
N
=
1
;
ck
::
index_t
C
=
1
;
ck
::
index_t
C
=
1
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
Hi
=
32
;
ck
::
index_t
Hi
=
32
;
ck
::
index_t
Wi
=
32
;
ck
::
index_t
Wi
=
32
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_h
=
0
;
ck
::
index_t
window_dilation_h
=
1
;
ck
::
index_t
in_left_pad_w
=
0
;
ck
::
index_t
window_dilation_w
=
1
;
ck
::
index_t
in_right_pad_h
=
0
;
ck
::
index_t
in_left_pad_h
=
0
;
ck
::
index_t
in_right_pad_w
=
0
;
ck
::
index_t
in_left_pad_w
=
0
;
ck
::
index_t
in_right_pad_h
=
0
;
ck
::
index_t
in_right_pad_w
=
0
;
bool
pass
=
maxpool_bwd_test
<
InDataType
,
bool
pass
=
maxpool_bwd_test
<
InDataType
,
OutDataType
,
OutDataType
,
...
@@ -53,6 +55,8 @@ int main()
...
@@ -53,6 +55,8 @@ int main()
Wi
,
Wi
,
window_stride_h
,
window_stride_h
,
window_stride_w
,
window_stride_w
,
window_dilation_h
,
window_dilation_w
,
in_left_pad_h
,
in_left_pad_h
,
in_left_pad_w
,
in_left_pad_w
,
in_right_pad_h
,
in_right_pad_h
,
...
...
include/ck/ck.hpp
View file @
cbf281f0
...
@@ -125,6 +125,9 @@
...
@@ -125,6 +125,9 @@
// `s_nop`s to avoid hazard
// `s_nop`s to avoid hazard
#define CK_USE_AMD_V_DOT_INLINE_ASM 0
#define CK_USE_AMD_V_DOT_INLINE_ASM 0
// inner product using V_DOT with DPP8 modifiers
#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp
0 → 100644
View file @
cbf281f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp"
namespace
ck
{
/**
* DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit
* the data loaded from LDS to registers.
*
* The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C
* between them in such a way that threads from the same group need the same chunk of either
* matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the
* whole chunk from LDS to its own register space.
* Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size`
* of the chunk, and then share that data with other threads from the same lane group.
*
* Assumptions coming from the usage of DPP8:
* 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or
* `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` -
* - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either
* matrix A or B;
* - based on these values we determine which matrix to share.
* 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or
* `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) -
* - we have to make sure that the data to split is divisible by the number of
* threads in the group.
*
* General algorithm:
* C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
* A and B are visible to the whole block, C is distributed among each thread
* Assume:
* 1. A:
* 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
* 2. ABlockBuffer is DynamicBuffer
* 2. B:
* 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
* 2. BBlockBuffer is DynamicBuffer
* 3. C:
* 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
* 2. CThreadBuffer is StaticBuffer
* 4. BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
*/
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
ABlockDesc_BK0_BM_BK1
,
typename
BBlockDesc_BK0_BN_BK1
,
index_t
BM1PerThreadBM11
,
index_t
BN1PerThreadBN11
,
index_t
BK0PerThread
,
typename
BM10BN10ThreadClusterBM10Xs
,
// Sequence<BM10BN10ThreadClusterBM100,
// BM10BN10ThreadClusterBM101, ...>
typename
BM10BN10ThreadClusterBN10Xs
,
// Sequence<BM10BN10ThreadClusterBN100,
// BM10BN10ThreadClusterBN101, ...>
index_t
AThreadCopyScalarPerVector_BM11
,
index_t
BThreadCopyScalarPerVector_BN11
,
typename
enable_if
<
ABlockDesc_BK0_BM_BK1
::
IsKnownAtCompileTime
()
&&
BBlockDesc_BK0_BN_BK1
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
{
using
AIndex
=
MultiIndex
<
4
>
;
using
BIndex
=
MultiIndex
<
4
>
;
using
CIndex
=
MultiIndex
<
4
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
BK0
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I0
);
static
constexpr
index_t
BK1
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I2
);
static
constexpr
index_t
BM
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I1
);
static
constexpr
index_t
BN
=
BBlockDesc_BK0_BN_BK1
{}.
GetLength
(
I1
);
static
constexpr
index_t
BM100
=
BM10BN10ThreadClusterBM10Xs
{}[
I0
];
static
constexpr
index_t
BN100
=
BM10BN10ThreadClusterBN10Xs
{}[
I0
];
static
constexpr
index_t
BM101
=
BM10BN10ThreadClusterBM10Xs
{}[
I1
];
static
constexpr
index_t
BN101
=
BM10BN10ThreadClusterBN10Xs
{}[
I1
];
static
constexpr
index_t
BM11
=
BM1PerThreadBM11
;
static
constexpr
index_t
BN11
=
BN1PerThreadBN11
;
static
constexpr
index_t
BM1
=
BM100
*
BM101
*
BM11
;
static
constexpr
index_t
BN1
=
BN100
*
BN101
*
BN11
;
static
constexpr
index_t
BM0
=
BM
/
BM1
;
static
constexpr
index_t
BN0
=
BN
/
BN1
;
// We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all
// threads in a lane group need the same chunk of B or A matrices and we can share them using
// DPP.
static_assert
(
BM101
==
dpp8
::
lane_group_size
||
BN101
==
dpp8
::
lane_group_size
);
static
constexpr
bool
ShareB
=
BM101
==
dpp8
::
lane_group_size
?
true
:
false
;
static
constexpr
bool
ShareA
=
!
ShareB
;
// If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`,
// respectively) elements, so we split them between threads in lane group so each thread loads
// less data from LDS.
static
constexpr
index_t
BM1PerThread
=
ShareA
?
BM1PerThreadBM11
/
dpp8
::
lane_group_size
:
BM1PerThreadBM11
;
static
constexpr
index_t
BN1PerThread
=
ShareB
?
BN1PerThreadBN11
/
dpp8
::
lane_group_size
:
BN1PerThreadBN11
;
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
const
ABlockDesc_BK0_BM_BK1
&
a_block_desc_bk0_bm_bk1
)
{
const
auto
a_block_bk0_bm0_bm1_bk1
=
transform_tensor_descriptor
(
a_block_desc_bk0_bm_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM1
>
{})),
make_pass_through_transform
(
Number
<
BK1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
a_block_bk0_bm0_bm1_bk1
;
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
const
BBlockDesc_BK0_BN_BK1
&
b_block_desc_bk0_bn_bk1
)
{
const
auto
b_block_desc_bk0_bn0_bn1_bk1
=
transform_tensor_descriptor
(
b_block_desc_bk0_bn_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN1
>
{})),
make_pass_through_transform
(
Number
<
BK1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
b_block_desc_bk0_bn0_bn1_bk1
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN
()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr
auto
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM100
>
{},
Number
<
BM101
>
{},
Number
<
BM11
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN100
>
{},
Number
<
BN101
>
{},
Number
<
BN11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
,
5
,
6
,
7
>
{}));
return
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1
()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr
auto
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
Number
<
BM0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM100
>
{},
Number
<
BM101
>
{},
Number
<
BM11
>
{})),
make_pass_through_transform
(
Number
<
BN0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN100
>
{},
Number
<
BN101
>
{},
Number
<
BN11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{}));
return
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1
;
}
__host__
__device__
static
constexpr
auto
GetCThreadTensorLengths_BM0_BM1_BN0_BN1
()
{
return
Sequence
<
BM0
,
BM11
,
BN0
,
BN11
>
{};
}
static
constexpr
auto
a_block_desc_bk0_bm0_bm1_bk1_
=
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
ABlockDesc_BK0_BM_BK1
{});
static
constexpr
auto
b_block_desc_bk0_bn0_bn1_bk1_
=
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
BBlockDesc_BK0_BN_BK1
{});
public:
__device__
BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
()
:
c_thread_origin_data_idx_
{
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
get_thread_local_1d_id
())},
a_thread_copy_
{
CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1
()},
b_thread_copy_
{
CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1
()}
{
static_assert
(
ABlockDesc_BK0_BM_BK1
::
IsKnownAtCompileTime
()
&&
BBlockDesc_BK0_BN_BK1
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BM
%
BM1
==
0
&&
BN
%
BN1
==
0
,
"wrong!"
);
static_assert
(
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I0
)
==
BBlockDesc_BK0_BN_BK1
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
static_assert
(
BM10BN10ThreadClusterBM10Xs
::
Size
()
==
2
&&
BM10BN10ThreadClusterBN10Xs
::
Size
()
==
2
,
"wrong!"
);
}
__device__
static
CIndex
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
index_t
thread_id
)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr
auto
adaptor0
=
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1
();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr
auto
adaptor1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
BM100
,
BN100
,
BM101
,
BN101
)),
make_pass_through_transform
(
BM0
),
make_pass_through_transform
(
BM11
),
make_pass_through_transform
(
BN0
),
make_pass_through_transform
(
BN11
)),
make_tuple
(
Sequence
<
1
,
5
,
2
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
constexpr
auto
adaptor
=
chain_tensor_adaptors
(
adaptor0
,
adaptor1
);
return
adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
,
0
,
0
,
0
,
0
));
}
__device__
AIndex
CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1
()
{
const
auto
offsetBM0
=
c_thread_origin_data_idx_
[
I0
];
// If sharing matrix A, we need a separate BM1 offset for each thread in lane group.
const
auto
offsetBM1
=
ShareA
?
c_thread_origin_data_idx_
[
I1
]
+
dpp8
::
get_thread_idx_in_lane_group
()
*
BM1PerThread
:
c_thread_origin_data_idx_
[
I1
];
return
make_tuple
(
0
,
offsetBM0
,
offsetBM1
,
0
);
}
__device__
BIndex
CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1
()
{
const
auto
offsetBN0
=
c_thread_origin_data_idx_
[
I2
];
// If sharing matrix B, we need a separate BN1 offset for each thread in lane group.
const
auto
offsetBN1
=
ShareB
?
c_thread_origin_data_idx_
[
I3
]
+
dpp8
::
get_thread_idx_in_lane_group
()
*
BN1PerThread
:
c_thread_origin_data_idx_
[
I3
];
return
make_tuple
(
0
,
offsetBN0
,
offsetBN1
,
0
);
}
template
<
typename
CThreadDesc_BM0_BM11_BN0_BN11
,
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
CThreadDesc_BM0_BM11_BN0_BN11
&
,
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
CThreadDesc_BM0_BM11_BN0_BN11
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_bk0_bm0_bm1_bk1_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_bk0_bn0_bn1_bk1_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_contraction
=
ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
CThreadDesc_BM0_BM11_BN0_BN11
,
Sequence
<
BK0PerThread
,
BK1
>
,
Sequence
<
1
,
BM1PerThreadBM11
>
,
Sequence
<
1
,
BN1PerThreadBN11
>
,
ShareA
>
{};
static_for
<
0
,
BN0
,
1
>
{}([
&
](
auto
bn0
)
{
static_for
<
0
,
BM0
,
1
>
{}([
&
](
auto
bm0
)
{
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
bm0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
bn0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
bm0
,
I0
,
bn0
,
I0
));
static_for
<
BK0PerThread
,
BK0
,
BK0PerThread
>
{}([
&
](
auto
bk0
)
{
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
bk0
,
bm0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
bk0
,
bn0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
bm0
,
I0
,
bn0
,
I0
));
});
});
});
}
private:
// A[BK0, BM0, BM1, BK1]
static
constexpr
auto
a_thread_desc_bk0_bm0_bm1_bk1_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BM0
>
{},
Number
<
BM1PerThread
>
{},
Number
<
BK1
>
{}));
// B[BK0, BN0, BN1, BK1]
static
constexpr
auto
b_thread_desc_bk0_bn0_bn1_bk1_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BN0
>
{},
Number
<
BN1PerThread
>
{},
Number
<
BK1
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_bk0_bm0_bm1_bk1_
),
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BM1PerThread
,
BK1
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
Sequence
<
1
,
1
,
BM1PerThread
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_bk0_bn0_bn1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BN1PerThread
,
BK1
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
Sequence
<
1
,
1
,
BN1PerThread
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
View file @
cbf281f0
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
namespace
ck
{
namespace
ck
{
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are vis
a
ble to the whole block, C is distributed among each thread
// A and B are vis
i
ble to the whole block, C is distributed among each thread
// Assume:
// Assume:
// 1. A:
// 1. A:
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
...
...
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
View file @
cbf281f0
...
@@ -17,6 +17,8 @@ template <index_t InOutRank,
...
@@ -17,6 +17,8 @@ template <index_t InOutRank,
typename
InDataType
,
typename
InDataType
,
typename
OutDataType
,
typename
OutDataType
,
typename
IndexDataType
,
typename
IndexDataType
,
typename
InLayout
,
typename
OutLayout
,
ReduceTensorOp
ReduceOpId
,
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
>
bool
OutputIndex
>
struct
DevicePoolFwd
:
public
BaseOperator
struct
DevicePoolFwd
:
public
BaseOperator
...
@@ -25,13 +27,14 @@ struct DevicePoolFwd : public BaseOperator
...
@@ -25,13 +27,14 @@ struct DevicePoolFwd : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in_dev
,
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
input_n_c_wis_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_xs_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
output_n_c_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
input_n_c_wis_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
output_n_c_wis_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
indices_n_c_wis_stride
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_xs_strides
,
std
::
vector
<
ck
::
index_t
>
window_xs_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
=
0
;
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
=
0
;
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment