Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f6ceef78
Commit
f6ceef78
authored
Aug 26, 2024
by
ThomasNing
Browse files
merge with the develop branch
parents
536c5458
25935b57
Changes
240
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
747 additions
and
95 deletions
+747
-95
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
+0
-2
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
+0
-2
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
+0
-2
codegen/test/rtc/CMakeLists.txt
codegen/test/rtc/CMakeLists.txt
+0
-2
codegen/test/rtc/src/kernel.cpp
codegen/test/rtc/src/kernel.cpp
+1
-1
codegen/test/rtc/src/tmp_dir.cpp
codegen/test/rtc/src/tmp_dir.cpp
+1
-1
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+57
-53
example/01_gemm/gemm_xdl_fp8.cpp
example/01_gemm/gemm_xdl_fp8.cpp
+2
-2
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+5
-5
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
+2
-2
example/12_reduce/reduce_blockwise.cpp
example/12_reduce/reduce_blockwise.cpp
+28
-1
example/12_reduce/reduce_blockwise_impl.hpp
example/12_reduce/reduce_blockwise_impl.hpp
+12
-2
example/12_reduce/reduce_example_common.hpp
example/12_reduce/reduce_example_common.hpp
+3
-2
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
+34
-11
example/20_grouped_conv_bwd_weight/common.hpp
example/20_grouped_conv_bwd_weight/common.hpp
+2
-6
example/62_convnd_activ/CMakeLists.txt
example/62_convnd_activ/CMakeLists.txt
+1
-0
example/62_convnd_activ/convscale_reduce/CMakeLists.txt
example/62_convnd_activ/convscale_reduce/CMakeLists.txt
+14
-0
example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp
...v/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp
+502
-0
example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp
...iv/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp
+82
-0
No files found.
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
View file @
f6ceef78
...
...
@@ -92,7 +92,6 @@ struct Epilogue
static_cast
<
int
>
(
prob
.
C
),
static_cast
<
int
>
(
prob
.
Y
),
static_cast
<
int
>
(
prob
.
X
)};
ck
::
Array
<
ck
::
index_t
,
5
>
d_lengths
=
{};
ck
::
Array
<
ck
::
index_t
,
5
>
in_strides
{
static_cast
<
int
>
(
prob
.
C
),
static_cast
<
int
>
(
prob
.
Hi
*
prob
.
Wi
*
prob
.
G
*
prob
.
C
),
...
...
@@ -109,7 +108,6 @@ struct Epilogue
1
,
static_cast
<
int
>
(
prob
.
X
*
prob
.
C
),
static_cast
<
int
>
(
prob
.
C
)};
ck
::
Array
<
ck
::
index_t
,
5
>
d_strides
=
{};
ck
::
Array
<
ck
::
index_t
,
2
>
conv_filter_strides
=
{
1
,
1
};
ck
::
Array
<
ck
::
index_t
,
2
>
conv_filter_dilations
=
{
1
,
1
};
...
...
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
View file @
f6ceef78
...
...
@@ -92,7 +92,6 @@ struct Epilogue
static_cast
<
int
>
(
prob
.
C
),
static_cast
<
int
>
(
prob
.
Y
),
static_cast
<
int
>
(
prob
.
X
)};
ck
::
Array
<
ck
::
index_t
,
5
>
d_lengths
=
{};
ck
::
Array
<
ck
::
index_t
,
5
>
in_strides
{
static_cast
<
int
>
(
prob
.
C
),
static_cast
<
int
>
(
prob
.
Hi
*
prob
.
Wi
*
prob
.
G
*
prob
.
C
),
...
...
@@ -109,7 +108,6 @@ struct Epilogue
1
,
static_cast
<
int
>
(
prob
.
X
*
prob
.
C
),
static_cast
<
int
>
(
prob
.
C
)};
ck
::
Array
<
ck
::
index_t
,
5
>
d_strides
=
{};
ck
::
Array
<
ck
::
index_t
,
2
>
conv_filter_strides
=
{
2
,
2
};
ck
::
Array
<
ck
::
index_t
,
2
>
conv_filter_dilations
=
{
1
,
1
};
...
...
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
View file @
f6ceef78
...
...
@@ -92,7 +92,6 @@ struct Epilogue
static_cast
<
int
>
(
prob
.
C
),
static_cast
<
int
>
(
prob
.
Y
),
static_cast
<
int
>
(
prob
.
X
)};
ck
::
Array
<
ck
::
index_t
,
5
>
d_lengths
=
{};
ck
::
Array
<
ck
::
index_t
,
5
>
in_strides
{
static_cast
<
int
>
(
prob
.
C
),
static_cast
<
int
>
(
prob
.
Hi
*
prob
.
Wi
*
prob
.
G
*
prob
.
C
),
...
...
@@ -109,7 +108,6 @@ struct Epilogue
1
,
static_cast
<
int
>
(
prob
.
X
*
prob
.
C
),
static_cast
<
int
>
(
prob
.
C
)};
ck
::
Array
<
ck
::
index_t
,
5
>
d_strides
=
{};
ck
::
Array
<
ck
::
index_t
,
2
>
conv_filter_strides
=
{
1
,
1
};
ck
::
Array
<
ck
::
index_t
,
2
>
conv_filter_dilations
=
{
1
,
1
};
...
...
codegen/test/rtc/CMakeLists.txt
View file @
f6ceef78
find_package
(
hip
)
file
(
GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp
)
add_library
(
ck_rtc
${
RTC_SOURCES
}
)
target_include_directories
(
ck_rtc PUBLIC include
)
...
...
codegen/test/rtc/src/kernel.cpp
View file @
f6ceef78
codegen/test/rtc/src/tmp_dir.cpp
View file @
f6ceef78
docs/sphinx/requirements.in
View file @
f6ceef78
rocm-docs-core==1.
6.0
rocm-docs-core==1.
7.2
sphinxcontrib-bibtex==2.6.2
docs/sphinx/requirements.txt
View file @
f6ceef78
...
...
@@ -4,33 +4,33 @@
#
# pip-compile requirements.in
#
accessible-pygments==0.0.
3
accessible-pygments==0.0.
5
# via pydata-sphinx-theme
alabaster==0.7.1
3
alabaster==0.7.1
6
# via sphinx
babel==2.1
2.1
babel==2.1
5.0
# via
# pydata-sphinx-theme
# sphinx
beautifulsoup4==4.1
1.2
beautifulsoup4==4.1
2.3
# via pydata-sphinx-theme
breathe==4.3
4
.0
breathe==4.3
5
.0
# via rocm-docs-core
certifi==202
3
.7.
22
certifi==202
4
.7.
4
# via requests
cffi==1.1
5.1
cffi==1.1
6.0
# via
# cryptography
# pynacl
charset-normalizer==3.
1.0
charset-normalizer==3.
3.2
# via requests
click==8.1.
3
click==8.1.
7
# via sphinx-external-toc
cryptography==4
1
.0.
6
cryptography==4
3
.0.
0
# via pyjwt
deprecated==1.2.1
3
deprecated==1.2.1
4
# via pygithub
docutils==0.
16
docutils==0.
21.2
# via
# breathe
# myst-parser
...
...
@@ -38,35 +38,35 @@ docutils==0.16
# pydata-sphinx-theme
# sphinx
# sphinxcontrib-bibtex
fastjsonschema==2.
18
.0
fastjsonschema==2.
20
.0
# via rocm-docs-core
gitdb==4.0.1
0
gitdb==4.0.1
1
# via gitpython
gitpython==3.1.3
7
gitpython==3.1.
4
3
# via rocm-docs-core
idna==3.
4
idna==3.
7
# via requests
imagesize==1.4.1
# via sphinx
jinja2==3.1.
2
jinja2==3.1.
4
# via
# myst-parser
# sphinx
latexcodec==
2
.0.
1
latexcodec==
3
.0.
0
# via pybtex
markdown-it-py==
2.2
.0
markdown-it-py==
3.0
.0
# via
# mdit-py-plugins
# myst-parser
markupsafe==2.1.
2
markupsafe==2.1.
5
# via jinja2
mdit-py-plugins==0.
3.5
mdit-py-plugins==0.
4.1
# via myst-parser
mdurl==0.1.2
# via markdown-it-py
myst-parser==
1
.0.
0
myst-parser==
3
.0.
1
# via rocm-docs-core
packaging==2
3.0
packaging==2
4.1
# via
# pydata-sphinx-theme
# sphinx
...
...
@@ -74,48 +74,46 @@ pybtex==0.24.0
# via
# pybtex-docutils
# sphinxcontrib-bibtex
pybtex-docutils==1.0.
2
pybtex-docutils==1.0.
3
# via sphinxcontrib-bibtex
pycparser==2.2
1
pycparser==2.2
2
# via cffi
pydata-sphinx-theme==0.1
3.3
pydata-sphinx-theme==0.1
5.4
# via
# rocm-docs-core
# sphinx-book-theme
pygithub==
1.58.1
pygithub==
2.3.0
# via rocm-docs-core
pygments==2.1
5
.0
pygments==2.1
8
.0
# via
# accessible-pygments
# pydata-sphinx-theme
# sphinx
pyjwt[crypto]==2.
6
.0
pyjwt[crypto]==2.
8
.0
# via pygithub
pynacl==1.5.0
# via pygithub
pyyaml==6.0
pyyaml==6.0
.1
# via
# myst-parser
# pybtex
# rocm-docs-core
# sphinx-external-toc
requests==2.3
1.0
requests==2.3
2.3
# via
# pygithub
# sphinx
rocm-docs-core==1.
6.0
rocm-docs-core==1.
7.2
# via -r requirements.in
six==1.16.0
# via
# latexcodec
# pybtex
smmap==5.0.0
# via pybtex
smmap==5.0.1
# via gitdb
snowballstemmer==2.2.0
# via sphinx
soupsieve==2.
4
soupsieve==2.
5
# via beautifulsoup4
sphinx==
5.3.0
sphinx==
7.4.7
# via
# breathe
# myst-parser
...
...
@@ -127,33 +125,39 @@ sphinx==5.3.0
# sphinx-external-toc
# sphinx-notfound-page
# sphinxcontrib-bibtex
sphinx-book-theme==1.
0.1
sphinx-book-theme==1.
1.3
# via rocm-docs-core
sphinx-copybutton==0.5.
1
sphinx-copybutton==0.5.
2
# via rocm-docs-core
sphinx-design==0.
4.1
sphinx-design==0.
6.0
# via rocm-docs-core
sphinx-external-toc==
0.3
.1
sphinx-external-toc==
1.0
.1
# via rocm-docs-core
sphinx-notfound-page==
0.8
.3
sphinx-notfound-page==
1.0
.3
# via rocm-docs-core
sphinxcontrib-applehelp==
1
.0.
4
sphinxcontrib-applehelp==
2
.0.
0
# via sphinx
sphinxcontrib-bibtex==2.6.2
# via -r requirements.in
sphinxcontrib-devhelp==
1
.0.
2
sphinxcontrib-devhelp==
2
.0.
0
# via sphinx
sphinxcontrib-htmlhelp==2.
0.1
sphinxcontrib-htmlhelp==2.
1.0
# via sphinx
sphinxcontrib-jsmath==1.0.1
# via sphinx
sphinxcontrib-qthelp==
1
.0.
3
sphinxcontrib-qthelp==
2
.0.
0
# via sphinx
sphinxcontrib-serializinghtml==
1.1.5
sphinxcontrib-serializinghtml==
2.0.0
# via sphinx
typing-extensions==4.5.0
# via pydata-sphinx-theme
urllib3==1.26.18
# via requests
wrapt==1.15.0
tomli==2.0.1
# via sphinx
typing-extensions==4.12.2
# via
# pydata-sphinx-theme
# pygithub
urllib3==2.2.2
# via
# pygithub
# requests
wrapt==1.16.0
# via deprecated
example/01_gemm/gemm_xdl_fp8.cpp
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
...
...
@@ -7,7 +7,7 @@
using
ADataType
=
ck
::
f8_t
;
using
BDataType
=
ck
::
f8_t
;
using
CDataType
=
ck
::
hal
f_t
;
using
CDataType
=
ck
::
f
8
_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
...
...
example/01_gemm/run_gemm_example.inc
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -34,11 +34,11 @@ inline __host__ __device__ constexpr double get_rtol()
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
1
e
-
1
;
// 240 and 224 are acceptable
return
2
e
-
1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
1.5
e-1
;
// 57344 and 49152 are acceptable
return
2
e
-
1
;
}
else
{
...
...
@@ -75,11 +75,11 @@ inline __host__ __device__ constexpr double get_atol()
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
16.1
;
// 240 and 224 are acceptable
return
2
e
-
1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
8192.1
;
// 57344 and 49152 are acceptable
return
2
e
-
1
;
}
else
{
...
...
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cassert>
...
...
@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
inline
HostTensorDescriptor
make_r0_host_tensor_descriptor
(
const
ck
::
utils
::
conv
::
ConvParam
&
problem_size
)
{
std
::
vector
<
ck
::
index_t
>
dimensions
{
problem_size
.
G_
,
problem_size
.
N_
};
std
::
vector
<
ck
::
long_
index_t
>
dimensions
{
problem_size
.
G_
,
problem_size
.
N_
};
ck
::
ranges
::
copy
(
problem_size
.
output_spatial_lengths_
,
std
::
back_inserter
(
dimensions
));
...
...
example/12_reduce/reduce_blockwise.cpp
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <initializer_list>
...
...
@@ -255,34 +255,61 @@ int main(int argc, char* argv[])
else
{
// for testing half_t
pass
=
pass
&&
reduce_blockwise_test
<
ck
::
half_t
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
pass
=
pass
&&
reduce_blockwise_test
<
ck
::
half_t
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
// for testing float
pass
=
pass
&&
reduce_blockwise_test
<
float
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
pass
=
pass
&&
reduce_blockwise_test
<
float
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
// for testing double
pass
=
pass
&&
reduce_blockwise_test
<
float
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
pass
=
pass
&&
reduce_blockwise_test
<
float
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
// for testing bhalf_t
pass
=
pass
&&
reduce_blockwise_test
<
ck
::
bhalf_t
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
pass
=
pass
&&
reduce_blockwise_test
<
ck
::
bhalf_t
,
float
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
// for testing int8_t
pass
=
pass
&&
reduce_blockwise_test
<
int8_t
,
int32_t
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
pass
=
pass
&&
reduce_blockwise_test
<
int8_t
,
int32_t
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// for testing int4_t using AVG operation
pass
=
pass
&&
reduce_blockwise_test
<
int4_t
,
int32_t
,
ReduceTensorOp
::
AVG
,
false
,
false
>
(
true
,
2
,
true
,
{
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
pass
=
pass
&&
reduce_blockwise_test
<
int4_t
,
int32_t
,
ReduceTensorOp
::
AVG
,
false
,
false
>
(
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
// for testing int4_t using MAX operation
pass
=
pass
&&
reduce_blockwise_test
<
int4_t
,
int8_t
,
ReduceTensorOp
::
MAX
,
false
,
false
>
(
true
,
2
,
true
,
{
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
,
3
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
pass
=
pass
&&
reduce_blockwise_test
<
int4_t
,
int8_t
,
ReduceTensorOp
::
MAX
,
false
,
false
>
(
true
,
2
,
true
,
{
16
,
64
,
32
,
960
},
{
0
,
1
,
2
},
1.0
f
,
0.0
f
);
#endif
...
...
example/12_reduce/reduce_blockwise_impl.hpp
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification,
auto
invoker_ptr
=
reduce
.
MakeInvokerPointer
();
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
int
log_level
=
0
,
cold_niters
=
5
,
nrepeat
=
50
;
if
(
beta
!=
0.0
f
)
{
std
::
cerr
<<
"Warning: With beta != 0.0f there must be only one repeat for correct results "
"since out memory is being overwritten."
<<
std
::
endl
;
cold_niters
=
0
;
nrepeat
=
1
;
}
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
,
log_level
,
cold_niters
,
nrepeat
});
std
::
size_t
num_bytes
=
invariant_total_length
*
reduce_total_length
*
sizeof
(
InOutDataType
)
+
invariant_total_length
*
sizeof
(
InOutDataType
);
...
...
example/12_reduce/reduce_example_common.hpp
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -38,7 +38,8 @@ struct ReduceShape
static
constexpr
ck
::
index_t
NumReduceDim_
=
NumReduceDim
;
};
using
reduce_shape_instances
=
std
::
tuple
<
ReduceShape
<
3
,
1
>
,
using
reduce_shape_instances
=
std
::
tuple
<
ReduceShape
<
12
,
3
>
,
ReduceShape
<
3
,
1
>
,
ReduceShape
<
3
,
2
>
,
ReduceShape
<
4
,
1
>
,
ReduceShape
<
4
,
2
>
,
...
...
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
// reset input to zero
in_device_buf
.
SetZero
();
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
conv_filter_strides_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
input_left_pads_i32
(
NDimSpatial
);
std
::
vector
<
ck
::
index_t
>
input_right_pads_i32
(
NDimSpatial
);
for
(
ck
::
index_t
d
=
0
;
d
<
NDimSpatial
;
d
++
)
{
input_spatial_lengths_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
input_spatial_lengths_
[
d
]);
filter_spatial_lengths_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
filter_spatial_lengths_
[
d
]);
output_spatial_lengths_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
GetOutputSpatialLengths
()[
d
]);
conv_filter_strides_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
conv_filter_strides_
[
d
]);
conv_filter_dilations_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
conv_filter_dilations_
[
d
]);
input_left_pads_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
input_left_pads_
[
d
]);
input_right_pads_i32
[
d
]
=
static_cast
<
ck
::
index_t
>
(
conv_param
.
input_right_pads_
[
d
]);
}
// do GEMM
auto
conv
=
DeviceConvNdBwdDataInstance
{};
auto
invoker
=
conv
.
MakeInvoker
();
...
...
@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
conv
.
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
C_
,
conv_param
.
input_spatial_lengths_
,
conv_param
.
filter_spatial_lengths_
,
conv_param
.
GetO
utput
S
patial
L
engths
()
,
conv_param
.
conv_filter_strides_
,
conv_param
.
conv_filter_dilations_
,
conv_param
.
input_left_pads_
,
conv_param
.
input_right_pads_
,
static_cast
<
ck
::
index_t
>
(
conv_param
.
N_
)
,
static_cast
<
ck
::
index_t
>
(
conv_param
.
K_
)
,
static_cast
<
ck
::
index_t
>
(
conv_param
.
C_
)
,
input_spatial_lengths_
i32
,
filter_spatial_lengths_
i32
,
o
utput
_s
patial
_l
engths
_i32
,
conv_filter_strides_
i32
,
conv_filter_dilations_
i32
,
input_left_pads_
i32
,
input_right_pads_
i32
,
in_element_op
,
wei_element_op
,
out_element_op
);
...
...
example/20_grouped_conv_bwd_weight/common.hpp
View file @
f6ceef78
...
...
@@ -23,12 +23,8 @@
using
BF16
=
ck
::
bhalf_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
#ifdef CK_ENABLE_FP8
using
F8
=
ck
::
f8_t
;
#endif
#ifdef CK_ENABLE_BF8
using
BF8
=
ck
::
bf8_t
;
#endif
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
example/62_convnd_activ/CMakeLists.txt
View file @
f6ceef78
...
...
@@ -3,6 +3,7 @@ add_subdirectory(convinvscale)
add_subdirectory
(
convscale
)
add_subdirectory
(
convscale_relu
)
add_subdirectory
(
convscale_add
)
add_subdirectory
(
convscale_reduce
)
add_subdirectory
(
multi_AB
)
add_subdirectory
(
unary
)
...
...
example/62_convnd_activ/convscale_reduce/CMakeLists.txt
0 → 100644
View file @
f6ceef78
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_convnd_activ_xdl_convscale_reduce
)
add_example_executable
(
example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp
)
add_example_dependencies
(
example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8
)
add_example_executable
(
example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp
)
add_example_dependencies
(
example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8
)
set
(
target 1
)
endif
()
endforeach
()
example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp
0 → 100644
View file @
f6ceef78
This diff is collapsed.
Click to expand it.
example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp
0 → 100644
View file @
f6ceef78
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using
InDataType
=
ck
::
f8_t
;
using
WeiDataType
=
ck
::
f8_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
ConvOutDataType
=
float
;
// data type of convolution result
using
OutDataType
=
ck
::
f8_t
;
// data type of final result
using
AComputeDataType
=
ck
::
f8_t
;
using
BComputeDataType
=
ck
::
f8_t
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
ConvScale
;
static
constexpr
auto
ConvSpec
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
template
<
ck
::
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
using
DeviceGroupedConvNDFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
InLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
OutLayout
,
InDataType
,
WeiDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<>
,
ConvOutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
ConvSpec
,
// ConvForwardSpecialization
GemmSpec
,
// GemmSpecialization
1
,
//
256
,
// BlockSize
128
,
// MPerBlock
256
,
// NPerBlock
32
,
// KPerBlock
8
,
// AK1
8
,
// BK1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
4
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
1
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
1
,
// BBlockLdsExtraN
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
AComputeDataType
,
BComputeDataType
>
;
#include "run_convnd_fwd_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_convnd_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
Prev
1
2
3
4
5
6
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment