gaoqiong / composable_kernel_ROCM / Commits

Commit f6ceef78, authored Aug 26, 2024 by ThomasNing

    merge with the develop branch

Parents: 536c5458, 25935b57

Changes: 240 files in total. This page shows 20 changed files, with 747 additions and 95 deletions (+747, -95).
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp                                  +0   -2
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp                                  +0   -2
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp                                  +0   -2
codegen/test/rtc/CMakeLists.txt                                                  +0   -2
codegen/test/rtc/src/kernel.cpp                                                  +1   -1
codegen/test/rtc/src/tmp_dir.cpp                                                 +1   -1
docs/sphinx/requirements.in                                                      +1   -1
docs/sphinx/requirements.txt                                                     +57  -53
example/01_gemm/gemm_xdl_fp8.cpp                                                 +2   -2
example/01_gemm/run_gemm_example.inc                                             +5   -5
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp                      +2   -2
example/12_reduce/reduce_blockwise.cpp                                           +28  -1
example/12_reduce/reduce_blockwise_impl.hpp                                      +12  -2
example/12_reduce/reduce_example_common.hpp                                      +3   -2
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp                            +34  -11
example/20_grouped_conv_bwd_weight/common.hpp                                    +2   -6
example/62_convnd_activ/CMakeLists.txt                                           +1   -0
example/62_convnd_activ/convscale_reduce/CMakeLists.txt                          +14  -0
example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp  +502 -0
example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp   +82  -0
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp

@@ -92,7 +92,6 @@ struct Epilogue
     static_cast<int>(prob.C), static_cast<int>(prob.Y), static_cast<int>(prob.X)};
-    ck::Array<ck::index_t, 5> d_lengths = {};
     ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
                                          static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
@@ -109,7 +108,6 @@ struct Epilogue
         1, static_cast<int>(prob.X * prob.C), static_cast<int>(prob.C)};
-    ck::Array<ck::index_t, 5> d_strides = {};
     ck::Array<ck::index_t, 2> conv_filter_strides   = {1, 1};
     ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp

@@ -92,7 +92,6 @@ struct Epilogue
     static_cast<int>(prob.C), static_cast<int>(prob.Y), static_cast<int>(prob.X)};
-    ck::Array<ck::index_t, 5> d_lengths = {};
     ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
                                          static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
@@ -109,7 +108,6 @@ struct Epilogue
         1, static_cast<int>(prob.X * prob.C), static_cast<int>(prob.C)};
-    ck::Array<ck::index_t, 5> d_strides = {};
     ck::Array<ck::index_t, 2> conv_filter_strides   = {2, 2};
     ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp

@@ -92,7 +92,6 @@ struct Epilogue
     static_cast<int>(prob.C), static_cast<int>(prob.Y), static_cast<int>(prob.X)};
-    ck::Array<ck::index_t, 5> d_lengths = {};
     ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
                                          static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
@@ -109,7 +108,6 @@ struct Epilogue
         1, static_cast<int>(prob.X * prob.C), static_cast<int>(prob.C)};
-    ck::Array<ck::index_t, 5> d_strides = {};
     ck::Array<ck::index_t, 2> conv_filter_strides   = {1, 1};
     ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
codegen/test/rtc/CMakeLists.txt

find_package(hip)
file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
add_library(ck_rtc ${RTC_SOURCES})
target_include_directories(ck_rtc PUBLIC include)
codegen/test/rtc/src/kernel.cpp

@@ -118,4 +118,4 @@ void kernel::launch(hipStream_t stream,
     launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
 }
-} // namespace rtc
\ No newline at end of file
+} // namespace rtc
codegen/test/rtc/src/tmp_dir.cpp

@@ -45,4 +45,4 @@ void tmp_dir::execute(const std::string& cmd) const
 tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); }
-} // namespace rtc
\ No newline at end of file
+} // namespace rtc
docs/sphinx/requirements.in

-rocm-docs-core==1.6.0
+rocm-docs-core==1.7.2
 sphinxcontrib-bibtex==2.6.2
docs/sphinx/requirements.txt

@@ -4,33 +4,33 @@
 #
 # pip-compile requirements.in
 #
-accessible-pygments==0.0.3
+accessible-pygments==0.0.5
     # via pydata-sphinx-theme
-alabaster==0.7.13
+alabaster==0.7.16
     # via sphinx
-babel==2.12.1
+babel==2.15.0
     # via
     #   pydata-sphinx-theme
     #   sphinx
-beautifulsoup4==4.11.2
+beautifulsoup4==4.12.3
     # via pydata-sphinx-theme
-breathe==4.34.0
+breathe==4.35.0
     # via rocm-docs-core
-certifi==2023.7.22
+certifi==2024.7.4
     # via requests
-cffi==1.15.1
+cffi==1.16.0
     # via
     #   cryptography
     #   pynacl
-charset-normalizer==3.1.0
+charset-normalizer==3.3.2
     # via requests
-click==8.1.3
+click==8.1.7
     # via sphinx-external-toc
-cryptography==41.0.6
+cryptography==43.0.0
     # via pyjwt
-deprecated==1.2.13
+deprecated==1.2.14
     # via pygithub
-docutils==0.16
+docutils==0.21.2
     # via
     #   breathe
     #   myst-parser
@@ -38,35 +38,35 @@ docutils==0.16
     #   pydata-sphinx-theme
     #   sphinx
     #   sphinxcontrib-bibtex
-fastjsonschema==2.18.0
+fastjsonschema==2.20.0
     # via rocm-docs-core
-gitdb==4.0.10
+gitdb==4.0.11
     # via gitpython
-gitpython==3.1.37
+gitpython==3.1.43
     # via rocm-docs-core
-idna==3.4
+idna==3.7
     # via requests
 imagesize==1.4.1
     # via sphinx
-jinja2==3.1.2
+jinja2==3.1.4
     # via
     #   myst-parser
     #   sphinx
-latexcodec==2.0.1
+latexcodec==3.0.0
     # via pybtex
-markdown-it-py==2.2.0
+markdown-it-py==3.0.0
     # via
     #   mdit-py-plugins
     #   myst-parser
-markupsafe==2.1.2
+markupsafe==2.1.5
     # via jinja2
-mdit-py-plugins==0.3.5
+mdit-py-plugins==0.4.1
     # via myst-parser
 mdurl==0.1.2
     # via markdown-it-py
-myst-parser==1.0.0
+myst-parser==3.0.1
     # via rocm-docs-core
-packaging==23.0
+packaging==24.1
     # via
     #   pydata-sphinx-theme
     #   sphinx
@@ -74,48 +74,46 @@ pybtex==0.24.0
     # via
     #   pybtex-docutils
     #   sphinxcontrib-bibtex
-pybtex-docutils==1.0.2
+pybtex-docutils==1.0.3
     # via sphinxcontrib-bibtex
-pycparser==2.21
+pycparser==2.22
     # via cffi
-pydata-sphinx-theme==0.13.3
+pydata-sphinx-theme==0.15.4
     # via
     #   rocm-docs-core
     #   sphinx-book-theme
-pygithub==1.58.1
+pygithub==2.3.0
     # via rocm-docs-core
-pygments==2.15.0
+pygments==2.18.0
     # via
     #   accessible-pygments
     #   pydata-sphinx-theme
     #   sphinx
-pyjwt[crypto]==2.6.0
+pyjwt[crypto]==2.8.0
     # via pygithub
 pynacl==1.5.0
     # via pygithub
-pyyaml==6.0
+pyyaml==6.0.1
     # via
     #   myst-parser
     #   pybtex
     #   rocm-docs-core
     #   sphinx-external-toc
-requests==2.31.0
+requests==2.32.3
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==1.6.0
+rocm-docs-core==1.7.2
     # via -r requirements.in
 six==1.16.0
     # via
     #   latexcodec
     #   pybtex
-smmap==5.0.0
+smmap==5.0.1
     # via gitdb
 snowballstemmer==2.2.0
     # via sphinx
-soupsieve==2.4
+soupsieve==2.5
     # via beautifulsoup4
-sphinx==5.3.0
+sphinx==7.4.7
     # via
     #   breathe
     #   myst-parser
@@ -127,33 +125,39 @@ sphinx==5.3.0
     #   sphinx-external-toc
     #   sphinx-notfound-page
     #   sphinxcontrib-bibtex
-sphinx-book-theme==1.0.1
+sphinx-book-theme==1.1.3
     # via rocm-docs-core
-sphinx-copybutton==0.5.1
+sphinx-copybutton==0.5.2
     # via rocm-docs-core
-sphinx-design==0.4.1
+sphinx-design==0.6.0
     # via rocm-docs-core
-sphinx-external-toc==0.3.1
+sphinx-external-toc==1.0.1
     # via rocm-docs-core
-sphinx-notfound-page==0.8.3
+sphinx-notfound-page==1.0.3
     # via rocm-docs-core
-sphinxcontrib-applehelp==1.0.4
+sphinxcontrib-applehelp==2.0.0
     # via sphinx
 sphinxcontrib-bibtex==2.6.2
     # via -r requirements.in
-sphinxcontrib-devhelp==1.0.2
+sphinxcontrib-devhelp==2.0.0
     # via sphinx
-sphinxcontrib-htmlhelp==2.0.1
+sphinxcontrib-htmlhelp==2.1.0
     # via sphinx
 sphinxcontrib-jsmath==1.0.1
     # via sphinx
-sphinxcontrib-qthelp==1.0.3
+sphinxcontrib-qthelp==2.0.0
     # via sphinx
-sphinxcontrib-serializinghtml==1.1.5
+sphinxcontrib-serializinghtml==2.0.0
     # via sphinx
-typing-extensions==4.5.0
-    # via pydata-sphinx-theme
-urllib3==1.26.18
-    # via requests
-wrapt==1.15.0
+tomli==2.0.1
+    # via sphinx
+typing-extensions==4.12.2
+    # via
+    #   pydata-sphinx-theme
+    #   pygithub
+urllib3==2.2.2
+    # via
+    #   pygithub
+    #   requests
+wrapt==1.16.0
     # via deprecated
example/01_gemm/gemm_xdl_fp8.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include "common.hpp"
@@ -7,7 +7,7 @@
 using ADataType        = ck::f8_t;
 using BDataType        = ck::f8_t;
-using CDataType        = ck::half_t;
+using CDataType        = ck::f8_t;
 using AccDataType      = float;
 using CShuffleDataType = float;
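Note: switching CDataType from ck::half_t to ck::f8_t stores the GEMM output in 8-bit float, whose main numeric effect is a much smaller representable range; the "240 and 224 are acceptable" comments elsewhere in this commit suggest 240 as the assumed maximum fp8 magnitude. A toy, std-only illustration of saturation at that assumed limit (this is not CK's conversion routine, and the 240 bound is an assumption taken from those comments):

#include <algorithm>
#include <cstdio>

// Clamp a float into the assumed fp8-representable range before conversion.
float saturate_to_f8_range(float x)
{
    const float f8_max = 240.0f; // assumed max representable magnitude
    return std::clamp(x, -f8_max, f8_max);
}

int main() { std::printf("%g\n", saturate_to_f8_range(300.0f)); } // prints 240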
example/01_gemm/run_gemm_example.inc

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -34,11 +34,11 @@ inline __host__ __device__ constexpr double get_rtol()
     }
     else if constexpr(std::is_same_v<DataType, ck::f8_t>)
     {
-        return 1e-1; // 240 and 224 are acceptable
+        return 2e-1;
     }
     else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
     {
-        return 1.5e-1; // 57344 and 49152 are acceptable
+        return 2e-1;
     }
     else
     {
@@ -75,11 +75,11 @@ inline __host__ __device__ constexpr double get_atol()
     }
     else if constexpr(std::is_same_v<DataType, ck::f8_t>)
     {
-        return 16.1; // 240 and 224 are acceptable
+        return 2e-1;
     }
     else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
     {
-        return 8192.1; // 57344 and 49152 are acceptable
+        return 2e-1;
     }
     else
     {
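The rtol/atol values above feed an allclose-style comparison: a value passes when |val - ref| <= atol + rtol * |ref|. A minimal std-only sketch of that criterion (within_tolerance is a hypothetical name, not the ck::utils::check_err signature):

#include <cmath>
#include <cstdio>

// Combined relative/absolute tolerance check, allclose-style.
bool within_tolerance(double ref, double val, double rtol, double atol)
{
    return std::fabs(val - ref) <= atol + rtol * std::fabs(ref);
}

int main()
{
    // With the relaxed fp8 rtol from this commit (2e-1), 224 is accepted
    // as an approximation of 240: |224 - 240| = 16 <= 0.2 * 240 = 48.
    std::printf("%d\n", within_tolerance(240.0, 224.0, 2e-1, 0.0)); // prints 1
}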
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <algorithm>
 #include <cassert>
@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
 inline HostTensorDescriptor
 make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
 {
-    std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_};
+    std::vector<ck::long_index_t> dimensions{problem_size.G_, problem_size.N_};

     ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions));
example/12_reduce/reduce_blockwise.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
 #include <initializer_list>
@@ -255,34 +255,61 @@ int main(int argc, char* argv[])
     else
     {
         // for testing half_t
+        pass = pass &&
+               reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
+                   true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass &&
                reduce_blockwise_test<ck::half_t, float, ReduceOpId, PropagateNan, OutputIndex>(
                    true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);

         // for testing float
+        pass = pass &&
+               reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
+                   true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass &&
                reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
                    true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);

         // for testing double
+        pass = pass &&
+               reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
+                   true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass &&
                reduce_blockwise_test<float, float, ReduceOpId, PropagateNan, OutputIndex>(
                    true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);

         // for testing bhalf_t
+        pass = pass &&
+               reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
+                   true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass &&
                reduce_blockwise_test<ck::bhalf_t, float, ReduceOpId, PropagateNan, OutputIndex>(
                    true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);

         // for testing int8_t
+        pass = pass &&
+               reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
+                   true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass &&
                reduce_blockwise_test<int8_t, int32_t, ReduceOpId, PropagateNan, OutputIndex>(
                    true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);

 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
         // for testing int4_t using AVG operation
+        pass = pass &&
+               reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
+                   true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass &&
                reduce_blockwise_test<int4_t, int32_t, ReduceTensorOp::AVG, false, false>(
                    true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);

         // for testing int4_t using MAX operation
+        pass = pass &&
+               reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
+                   true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f);
         pass = pass &&
                reduce_blockwise_test<int4_t, int8_t, ReduceTensorOp::MAX, false, false>(
                    true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f);
 #endif
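The added cases exercise a rank-12 path: lengths {3, 3, ..., 3} reduced over dims {0, 1, 2} give a reduce length of 3^3 = 27 and an invariant length of 3^9 = 19683. A small standalone check of that arithmetic (std-only sketch; names are hypothetical, not from the example):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    // Shape and reduce dims of the newly added rank-12 test cases.
    std::vector<std::size_t> lengths(12, 3);
    std::vector<std::size_t> reduce_dims{0, 1, 2};

    std::size_t reduce_total = 1, invariant_total = 1;
    for(std::size_t d = 0; d < lengths.size(); ++d)
    {
        bool reduced =
            std::find(reduce_dims.begin(), reduce_dims.end(), d) != reduce_dims.end();
        (reduced ? reduce_total : invariant_total) *= lengths[d];
    }
    // reduce_total = 27, invariant_total = 19683
    std::printf("reduce=%zu invariant=%zu\n", reduce_total, invariant_total);
}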
example/12_reduce/reduce_blockwise_impl.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
     auto invoker_ptr = reduce.MakeInvokerPointer();

-    float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+    int log_level = 0, cold_niters = 5, nrepeat = 50;
+    if(beta != 0.0f)
+    {
+        std::cerr << "Warning: With beta != 0.0f there must be only one repeat for correct results "
+                     "since out memory is being overwritten."
+                  << std::endl;
+        cold_niters = 0;
+        nrepeat     = 1;
+    }
+    float avg_time = invoker_ptr->Run(
+        argument_ptr.get(), StreamConfig{nullptr, time_kernel, log_level, cold_niters, nrepeat});

     std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) +
                             invariant_total_length * sizeof(InOutDataType);
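The new timing path passes cold_niters and nrepeat through StreamConfig so that warm-up launches are excluded and the reported time is an average over the timed repeats (and with beta != 0 the output buffer is consumed, so only a single repeat gives correct results). A minimal host-side sketch of that measurement pattern, assuming only the standard library (run_once stands in for the kernel launch):

#include <chrono>
#include <functional>

// Average wall time per run in milliseconds, excluding warm-up iterations.
float time_with_warmup(const std::function<void()>& run_once, int cold_niters, int nrepeat)
{
    for(int i = 0; i < cold_niters; ++i)
        run_once(); // warm-up: not timed

    auto t0 = std::chrono::steady_clock::now();
    for(int i = 0; i < nrepeat; ++i)
        run_once();
    auto t1 = std::chrono::steady_clock::now();

    return std::chrono::duration<float, std::milli>(t1 - t0).count() / nrepeat;
}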
example/12_reduce/reduce_example_common.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -38,7 +38,8 @@ struct ReduceShape
     static constexpr ck::index_t NumReduceDim_ = NumReduceDim;
 };

-using reduce_shape_instances = std::tuple<ReduceShape<3, 1>,
+using reduce_shape_instances = std::tuple<ReduceShape<12, 3>,
+                                          ReduceShape<3, 1>,
                                           ReduceShape<3, 2>,
                                           ReduceShape<4, 1>,
                                           ReduceShape<4, 2>,
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <iostream>
 #include <numeric>
@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
     // reset input to zero
     in_device_buf.SetZero();

+    std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
+    std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
+    std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
+    std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
+
+    for(ck::index_t d = 0; d < NDimSpatial; d++)
+    {
+        input_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
+        filter_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
+        output_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
+        conv_filter_strides_i32[d] =
+            static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
+        conv_filter_dilations_i32[d] =
+            static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
+        input_left_pads_i32[d]  = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
+        input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
+    }
+
     // do GEMM
     auto conv    = DeviceConvNdBwdDataInstance{};
     auto invoker = conv.MakeInvoker();
@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
         conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                                  static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                                  static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
-                                 conv_param.N_,
-                                 conv_param.K_,
-                                 conv_param.C_,
-                                 conv_param.input_spatial_lengths_,
-                                 conv_param.filter_spatial_lengths_,
-                                 conv_param.GetOutputSpatialLengths(),
-                                 conv_param.conv_filter_strides_,
-                                 conv_param.conv_filter_dilations_,
-                                 conv_param.input_left_pads_,
-                                 conv_param.input_right_pads_,
+                                 static_cast<ck::index_t>(conv_param.N_),
+                                 static_cast<ck::index_t>(conv_param.K_),
+                                 static_cast<ck::index_t>(conv_param.C_),
+                                 input_spatial_lengths_i32,
+                                 filter_spatial_lengths_i32,
+                                 output_spatial_lengths_i32,
+                                 conv_filter_strides_i32,
+                                 conv_filter_dilations_i32,
+                                 input_left_pads_i32,
+                                 input_right_pads_i32,
                                  in_element_op,
                                  wei_element_op,
                                  out_element_op);
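The added block exists because ConvParam now stores 64-bit lengths (ck::long_index_t) while this device API takes 32-bit ck::index_t. A standalone sketch of the same narrowing, with an overflow check added purely for illustration (the example above casts directly; narrow_to_i32 is a hypothetical helper):

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Narrow 64-bit problem sizes to the 32-bit indices a device API expects.
std::vector<int32_t> narrow_to_i32(const std::vector<int64_t>& v)
{
    std::vector<int32_t> out(v.size());
    for(std::size_t i = 0; i < v.size(); ++i)
    {
        if(v[i] > std::numeric_limits<int32_t>::max())
            throw std::overflow_error("length does not fit in a 32-bit index");
        out[i] = static_cast<int32_t>(v[i]); // narrowing cast, checked above
    }
    return out;
}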
example/20_grouped_conv_bwd_weight/common.hpp

@@ -23,12 +23,8 @@
 using BF16 = ck::bhalf_t;
 using F16  = ck::half_t;
 using F32  = float;
-#ifdef CK_ENABLE_FP8
-using F8 = ck::f8_t;
-#endif
-#ifdef CK_ENABLE_BF8
-using BF8 = ck::bf8_t;
-#endif
+using F8  = ck::f8_t;
+using BF8 = ck::bf8_t;

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
example/62_convnd_activ/CMakeLists.txt

@@ -3,6 +3,7 @@ add_subdirectory(convinvscale)
 add_subdirectory(convscale)
 add_subdirectory(convscale_relu)
 add_subdirectory(convscale_add)
+add_subdirectory(convscale_reduce)
 add_subdirectory(multi_AB)
 add_subdirectory(unary)
example/62_convnd_activ/convscale_reduce/CMakeLists.txt (new file, mode 100644)

list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
        add_custom_target(example_convnd_activ_xdl_convscale_reduce)

        add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8
                               convnd_fwd_xdl_convscale_relu_amax_fp8.cpp)
        add_example_dependencies(example_convnd_activ_xdl_convscale_reduce
                                 example_convnd_fwd_xdl_convscale_relu_amax_fp8)

        add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8
                               convnd_fwd_xdl_convscale_amax_fp8.cpp)
        add_example_dependencies(example_convnd_activ_xdl_convscale_reduce
                                 example_convnd_fwd_xdl_convscale_amax_fp8)

        set(target 1)
    endif()
endforeach()
example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iostream>

#include "ck/ck.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/type.hpp"

namespace ew = ck::tensor_operation::element_wise;

using PassThrough   = ew::PassThrough;
using ConvScaleRelu = ew::UnaryCombinedOp<ew::Scale, ew::Scale, ew::Relu>;
using ConvScale     = ew::UnaryCombinedOp<ew::Scale, ew::Scale, PassThrough>;

using UnaryScaleConvert = ew::Scale;

void print_helper_msg()
{
    std::cout << "arg1: verification (0=no, 1=yes)\n"
              << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
              << "arg3: time kernel (0=no, 1=yes)\n"
              << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}

template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    if constexpr(std::is_same_v<DataType, float>) { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, double>) { return 1e-6; }
    else if constexpr(std::is_same_v<DataType, ck::half_t>) { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>) { return 5e-2; }
    else if constexpr(std::is_same_v<DataType, int32_t>) { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, int8_t>) { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>) { return 1e-1; }    // 240 and 224 are acceptable
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>) { return 1.5e-1; } // 57344 and 49152 are acceptable
    else { return 1e-3; }
}

template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    if constexpr(std::is_same_v<DataType, float>) { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, double>) { return 1e-6; }
    else if constexpr(std::is_same_v<DataType, ck::half_t>) { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>) { return 5e-2; }
    else if constexpr(std::is_same_v<DataType, int32_t>) { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, int8_t>) { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>) { return 16.1; }    // 240 and 224 are acceptable
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>) { return 8192.1; } // 57344 and 49152 are acceptable
    else { return 1e-3; }
}

template <ck::index_t NDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename ConvOutDataType,
          typename OutDataType,
          typename InElementOp,
          typename WeiElementOp,
          typename ConvElementOp,
          typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
                          int init_method,
                          bool time_kernel,
                          const ck::utils::conv::ConvParam& conv_param,
                          const HostTensorDescriptor& in_g_n_c_wis_desc,
                          const HostTensorDescriptor& wei_g_k_c_xs_desc,
                          const HostTensorDescriptor& out_g_n_k_wos_desc,
                          const InElementOp& in_element_op,
                          const WeiElementOp& wei_element_op)
{
    Tensor<InDataType> in(in_g_n_c_wis_desc);
    Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
    Tensor<ConvOutDataType> host_conv(out_g_n_k_wos_desc);
    Tensor<ConvOutDataType> device_conv(out_g_n_k_wos_desc);
    Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
    Tensor<OutDataType> out_device(out_g_n_k_wos_desc);

    std::cout << "in: " << in.mDesc << std::endl;
    std::cout << "wei: " << wei.mDesc << std::endl;
    std::cout << "out: " << out_host.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
        wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
        break;
    case 11:
        // used for debugging
        in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
        wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
        break;
    default:
        in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
        wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
    }

    DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
    DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
    DeviceMem conv_device_buf(conv_param.GetOutputByte<ConvOutDataType>());
    DeviceMem out_device_buf(conv_param.GetOutputByte<OutDataType>());

    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());

    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
    std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
    std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
    std::array<ck::index_t, NDimSpatial> input_left_pads{};
    std::array<ck::index_t, NDimSpatial> input_right_pads{};

    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };

    copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
    copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
    copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
    copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
    copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
    copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
    copy(conv_param.conv_filter_strides_, conv_filter_strides);
    copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
    copy(conv_param.input_left_pads_, input_left_pads);
    copy(conv_param.input_right_pads_, input_right_pads);

    // random scale values
    float scale_in  = float(std::rand()) / float(RAND_MAX);
    float scale_wei = float(std::rand()) / float(RAND_MAX);
    float scale_out = float(std::rand()) / float(RAND_MAX);

    std::cout << std::endl;
    std::cout << "scale_in: " << scale_in << std::endl;
    std::cout << "scale_wei: " << scale_wei << std::endl;
    std::cout << "scale_out: " << scale_out << std::endl;

    // convolution elementwise operation
    auto conv_element_op = ConvElementOp{ew::Scale{scale_in}, ew::Scale{scale_wei}, {}};
    auto scale_convert   = UnaryScaleConvert{scale_out}; // elementwise scale and type cast

    // do Conv
    auto conv          = DeviceConvNDFwdInstance{};
    auto conv_invoker  = conv.MakeInvoker();
    auto conv_argument = conv.MakeArgument(
        in_device_buf.GetDeviceBuffer(),
        wei_device_buf.GetDeviceBuffer(),
        std::array<const void*, 0>{},
        conv_device_buf.GetDeviceBuffer(),
        a_g_n_c_wis_lengths,
        a_g_n_c_wis_strides,
        b_g_k_c_xs_lengths,
        b_g_k_c_xs_strides,
        std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
        std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
        e_g_n_k_wos_lengths,
        e_g_n_k_wos_strides,
        conv_filter_strides,
        conv_filter_dilations,
        input_left_pads,
        input_right_pads,
        in_element_op,
        wei_element_op,
        conv_element_op);

    if(!conv.IsSupportedArgument(conv_argument))
    {
        throw std::runtime_error(
            "wrong! device_conv with the specified compilation parameters does "
            "not support this Conv problem");
    }

    std::string kernels = conv.GetTypeString();
    float avg_time      = conv_invoker.Run(conv_argument, StreamConfig{nullptr, time_kernel});

    using DeviceElementwiseScale =
        ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ConvOutDataType>, // InDataTypeTuple
                                                            ck::Tuple<OutDataType>,     // OutDataTypeTuple
                                                            UnaryScaleConvert,  // UnaryScaleConvert
                                                            NDimSpatial + 3,    // NumDim
                                                            256,                // BlockSize
                                                            128,                // M0PerBlock
                                                            128,                // M1PerBlock
                                                            8,                  // M0PerThread
                                                            8,                  // M1PerThread
                                                            ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
                                                            ck::Sequence<8>,    // InScalarPerVectorSeq
                                                            ck::Sequence<8>>;   // OutScalarPerVectorSeq

    auto device_ew_scale = DeviceElementwiseScale{};
    auto scale_invoker   = device_ew_scale.MakeInvoker();
    auto scale_argument  = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths,
                                                        {e_g_n_k_wos_strides},
                                                        {e_g_n_k_wos_strides},
                                                        {conv_device_buf.GetDeviceBuffer()},
                                                        {out_device_buf.GetDeviceBuffer()},
                                                        scale_convert);

    if(!device_ew_scale.IsSupportedArgument(scale_argument))
    {
        throw std::runtime_error(
            "wrong! DeviceElementwiseScale with the specified compilation parameters does "
            "not support this problem");
    }

    kernels += std::string("\n\t\t") + device_ew_scale.GetTypeString();
    avg_time += scale_invoker.Run(scale_argument, StreamConfig{nullptr, time_kernel});

    constexpr auto ReduceOpId = ck::ReduceTensorOp::AMAX;
    using ReduceOperation     = typename ck::reduce_binary_operator<ReduceOpId>::opType;
    using InElementwiseOperation =
        typename ck::reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
    using AccElementwiseOperation =
        typename ck::reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;

    using DeviceReduceInstance =
        ck::tensor_operation::device::DeviceReduceMultiBlock<ConvOutDataType,
                                                             ConvOutDataType,
                                                             ConvOutDataType,
                                                             NDimSpatial + 3,
                                                             NDimSpatial + 3,
                                                             ReduceOperation,
                                                             InElementwiseOperation,
                                                             AccElementwiseOperation,
                                                             ck::InMemoryDataOperationEnum::Set,
                                                             true,  // PropagateNan
                                                             false, // OutputIndex
                                                             false, // HaveIndexInputIfOutputIndex
                                                             256,   // BlockSize
                                                             4,     // MThreadClusterSize
                                                             64,    // KThreadClusterSize
                                                             1,     // MThreadSliceSize
                                                             1,     // KThreadSliceSize
                                                             1,     // InSrcVectorDim
                                                             1,     // InSrceVectorSize
                                                             1>;    // OutDstVectorSize

    std::vector<size_t> outLengths = {1};
    Tensor<ConvOutDataType> amax_host(outLengths);
    Tensor<ConvOutDataType> amax_from_device(outLengths);
    auto amax_host_strides = amax_host.mDesc.GetStrides();

    std::array<int, NDimSpatial + 3> reduce_dims;
    std::iota(reduce_dims.begin(), reduce_dims.end(), 0); // 0,..., NDimSpatial+3-1

    std::array<ck::index_t, 1> reduce_out_lengths{1};
    std::array<ck::index_t, 1> reduce_out_strides{static_cast<ck::index_t>(amax_host_strides[0])};

    DeviceMem amax_device(sizeof(ConvOutDataType) * amax_host.mDesc.GetElementSpaceSize());
    DeviceMem index_device;

    InElementwiseOperation in_elementwise_op;
    AccElementwiseOperation acc_elementwise_op;
    std::tie(in_elementwise_op, acc_elementwise_op) =
        ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
            static_cast<int32_t>(host_conv.mDesc.GetElementSize()));

    // Hack convolution output strides for reduction as kernel expects stride 1 for the last
    // dimension. It only works because the reduction is done on the whole tensor and result is
    // independent of the order of elements.
    std::array<ck::index_t, NDimSpatial + 3> reduction_strides{};
    copy(HostTensorDescriptor(e_g_n_k_wos_lengths).GetStrides(), reduction_strides);

    auto device_reduce   = DeviceReduceInstance{};
    auto reduce_invoker  = device_reduce.MakeInvokerPointer();
    auto reduce_argument = device_reduce.MakeArgumentPointer(e_g_n_k_wos_lengths,
                                                             reduction_strides,
                                                             reduce_out_lengths,
                                                             reduce_out_strides,
                                                             reduce_dims,
                                                             1.0,
                                                             0.0,
                                                             conv_device_buf.GetDeviceBuffer(),
                                                             nullptr,
                                                             amax_device.GetDeviceBuffer(),
                                                             nullptr,
                                                             in_elementwise_op,
                                                             acc_elementwise_op);

    if(!device_reduce.IsSupportedArgument(reduce_argument.get()))
    {
        throw std::runtime_error(
            "wrong! DeviceReduceInstance with the specified compilation parameters does "
            "not support this runtime parameters!");
    };

    kernels += std::string("\n\t\t") + device_reduce.GetTypeString();
    float reduce_time =
        reduce_invoker->Run(reduce_argument.get(), StreamConfig{nullptr, time_kernel});
    if(time_kernel)
        std::cout << "\nReduce time: " << reduce_time << " ms" << std::endl;
    avg_time += reduce_time;

    std::size_t flop    = conv_param.GetFlops();      // convolution FLOPs
    auto conv_out_elems = host_conv.GetElementSize(); // number of elements in conv result tensor
    // 3 element-wise scale multipliers + 1 AMAX
    std::size_t elementwise_ops = 3 + 1;
    if constexpr(ck::is_same_v<ConvElementOp, ConvScaleRelu>)
    {
        elementwise_ops += 1; // +1 element-wise relu
    }
    flop += elementwise_ops * conv_out_elems;

    // convolution + elementwise scaling (in + wei + output byte count)
    std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, ConvOutDataType>();
    num_btype += sizeof(float) + sizeof(float); // + 2 scales
    // elementwise scaling + F8 conversion
    num_btype += conv_param.GetOutputByte<ConvOutDataType>() + sizeof(float) +
                 conv_param.GetOutputByte<OutDataType>();
    // AMAX
    num_btype += conv_param.GetOutputByte<ConvOutDataType>() + sizeof(float);

    if(time_kernel)
    {
        float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
        float gb_per_sec = num_btype / 1.E6 / avg_time;
        std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << std::endl;
    }
    std::cout << "\nKernels: " << kernels << std::endl;

    if(do_verification)
    {
        auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
                                                                     InDataType,
                                                                     WeiDataType,
                                                                     ConvOutDataType,
                                                                     InElementOp,
                                                                     WeiElementOp,
                                                                     ConvElementOp>();

        auto ref_invoker  = ref_conv.MakeInvoker();
        auto ref_argument = ref_conv.MakeArgument(in,
                                                  wei,
                                                  host_conv,
                                                  conv_param.conv_filter_strides_,
                                                  conv_param.conv_filter_dilations_,
                                                  conv_param.input_left_pads_,
                                                  conv_param.input_right_pads_,
                                                  in_element_op,
                                                  wei_element_op,
                                                  conv_element_op);

        ref_invoker.Run(ref_argument);

        conv_device_buf.FromDevice(device_conv.mData.data());
        out_device_buf.FromDevice(out_device.mData.data());

        out_host.ForEach([&](auto&, auto idx) { scale_convert(out_host(idx), host_conv(idx)); });

        std::cout << "\nComparing output to reference: " << std::endl;
        auto tight_tol_check = ck::utils::check_err(out_device, out_host, "Error: ");
        if(!tight_tol_check)
        {
            std::cout << "\n\tRecompare applying tolerances...\n";
            std::cout << "\t\trtol = " << get_rtol<OutDataType>() << std::endl;
            std::cout << "\t\tatol = " << get_atol<OutDataType>() << std::endl;

            auto loose_tol_check = ck::utils::check_err(out_device,
                                                        out_host,
                                                        "Error: incorrect convolution results!",
                                                        get_rtol<OutDataType>(),
                                                        get_atol<OutDataType>());
            if(!loose_tol_check)
            {
                return false;
            }
        }
        std::cout << "Success!" << std::endl;

        /// Verify AMAX
        using RefReduceInstance =
            ck::tensor_operation::host::ReferenceReduce<ConvOutDataType,
                                                        ConvOutDataType,
                                                        ConvOutDataType,
                                                        NDimSpatial + 3,
                                                        NDimSpatial + 3,
                                                        ReduceOperation,
                                                        InElementwiseOperation,
                                                        AccElementwiseOperation,
                                                        true,
                                                        false>;
        auto ref_reduce          = RefReduceInstance{};
        auto ref_reduce_invoker  = ref_reduce.MakeInvokerPointer();
        auto ref_reduce_argument = ref_reduce.MakeArgumentPointer(e_g_n_k_wos_lengths,
                                                                  e_g_n_k_wos_strides,
                                                                  reduce_out_lengths,
                                                                  reduce_out_strides,
                                                                  reduce_dims,
                                                                  1.0,
                                                                  0.0,
                                                                  host_conv.mData.data(),
                                                                  nullptr,
                                                                  amax_host.mData.data(),
                                                                  nullptr,
                                                                  in_elementwise_op,
                                                                  acc_elementwise_op);

        if(!ref_reduce.IsSupportedArgument(ref_reduce_argument.get()))
        {
            throw std::runtime_error(
                "wrong! RefReduceInstance with the specified compilation parameters does "
                "not support this runtime parameters!");
        };

        ref_reduce_invoker->Run(ref_reduce_argument.get());

        amax_device.FromDevice(amax_from_device.mData.data());

        std::cout << "\namax: " << amax_from_device.mData[0] << std::endl;
        std::cout << "amax_ref: " << amax_host.mData[0] << std::endl;

        return ck::utils::check_err(amax_from_device, amax_host, "Error: incorrect AMAX results!");
    }

    return true;
}
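For reference, the AMAX quantity this new example computes on the GPU and verifies against ReferenceReduce is simply the maximum absolute value over the whole convolution output, the statistic commonly used to derive fp8 quantization scales. A host-only sketch of the same reduction (std containers only; amax_reference is a hypothetical helper, not part of the file above):

#include <algorithm>
#include <cmath>
#include <vector>

// AMAX: max |x| over all elements of the (flattened) convolution output.
float amax_reference(const std::vector<float>& conv_out)
{
    float amax = 0.0f;
    for(float v : conv_out)
        amax = std::max(amax, std::fabs(v));
    return amax;
}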
example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_convscale_reduce_common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"

using InDataType       = ck::f8_t;
using WeiDataType      = ck::f8_t;
using AccDataType      = float;
using CShuffleDataType = float;
using ConvOutDataType  = float;    // data type of convolution result
using OutDataType      = ck::f8_t; // data type of final result

using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using InElementOp  = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;

static constexpr auto ConvSpec =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
        NDimSpatial,
        InLayout,
        WeiLayout,
        ck::Tuple<>,
        OutLayout,
        InDataType,
        WeiDataType,
        AccDataType,
        CShuffleDataType,
        ck::Tuple<>,
        ConvOutDataType,
        InElementOp,
        WeiElementOp,
        OutElementOp,
        ConvSpec,       // ConvForwardSpecialization
        GemmSpec,       // GemmSpecialization
        1,              //
        256,            // BlockSize
        128,            // MPerBlock
        256,            // NPerBlock
        32,             // KPerBlock
        8,              // AK1
        8,              // BK1
        32,             // MPerXdl
        32,             // NPerXdl
        2,              // MXdlPerWave
        4,              // NXdlPerWave
        S<4, 64, 1>,    // ABlockTransferThreadClusterLengths_AK0_M_AK1
        S<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,     // ABlockTransferSrcAccessOrder
        2,              // ABlockTransferSrcVectorDim
        8,              // ABlockTransferSrcScalarPerVector
        8,              // ABlockTransferDstScalarPerVector_AK1
        1,              // ABlockLdsExtraM
        S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_BK0_N_BK1
        S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
        2,              // BBlockTransferSrcVectorDim
        8,              // BBlockTransferSrcScalarPerVector
        8,              // BBlockTransferDstScalarPerVector_BK1
        1,              // BBlockLdsExtraN
        1,
        1,
        S<1, 32, 1, 8>,
        8,
        AComputeDataType,
        BComputeDataType>;

#include "run_convnd_fwd_example.inc"

int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }