Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6e3c786e
Commit
6e3c786e
authored
Dec 06, 2024
by
Jing Zhang
Browse files
merge develop
parents
1bb510cb
261f1759
Changes
465
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
463 additions
and
96 deletions
+463
-96
codegen/test/rtc/include/rtc/filesystem.hpp
codegen/test/rtc/include/rtc/filesystem.hpp
+60
-0
codegen/test/rtc/include/rtc/tmp_dir.hpp
codegen/test/rtc/include/rtc/tmp_dir.hpp
+2
-2
codegen/test/rtc/src/compile_kernel.cpp
codegen/test/rtc/src/compile_kernel.cpp
+5
-5
codegen/test/rtc/src/tmp_dir.cpp
codegen/test/rtc/src/tmp_dir.cpp
+3
-3
docs/reference/API_Reference_Guide.rst
docs/reference/API_Reference_Guide.rst
+0
-6
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+1
-1
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+7
-0
example/01_gemm/common.hpp
example/01_gemm/common.hpp
+21
-20
example/01_gemm/gemm_wmma_bf16.cpp
example/01_gemm/gemm_wmma_bf16.cpp
+84
-0
example/01_gemm/gemm_wmma_int8.cpp
example/01_gemm/gemm_wmma_int8.cpp
+84
-0
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp
+12
-1
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
+58
-0
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+23
-20
example/01_gemm/run_gemm_example_streamk_v2.inc
example/01_gemm/run_gemm_example_streamk_v2.inc
+41
-1
example/01_gemm/run_gemm_example_v2.inc
example/01_gemm/run_gemm_example_v2.inc
+8
-8
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
+1
-1
example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc
...multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc
+41
-16
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
..._grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
+6
-6
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
+5
-5
No files found.
Too many changes to show.
To preserve performance only
465 of 465+
files are displayed.
Plain diff
Email patch
codegen/test/rtc/include/rtc/filesystem.hpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#ifndef GUARD_TEST_HOST_RTC_FILESYSTEM_HPP
#define GUARD_TEST_HOST_RTC_FILESYSTEM_HPP

// Picks a filesystem library for the RTC test helpers and publishes it as the
// namespace alias rtc::fs.
//
//   RTC_HAS_FILESYSTEM    == 1  ->  <filesystem>,              rtc::fs = std::filesystem
//   RTC_HAS_FILESYSTEM_TS == 1  ->  <experimental/filesystem>, rtc::fs = std::experimental::filesystem
//
// <filesystem> wins when both are reported; when neither is, compilation stops
// with an #error.

#include <string>
#include <string_view>

// clang-format off
#if defined(CPPCHECK)
// CPPCHECK defined (presumably a cppcheck analysis run): report both flavours.
#define RTC_HAS_FILESYSTEM 1
#define RTC_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32) && defined(_MSC_VER)
// MSVC-compatible toolchains are classified by compiler version. Requiring
// _MSC_VER keeps other Windows compilers out of this branch: with the macro
// undefined every comparison below is false, which used to end in the #error
// even when <filesystem> was usable. They now reach the __has_include probe.
#if _MSC_VER >= 1920
#define RTC_HAS_FILESYSTEM 1
#define RTC_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define RTC_HAS_FILESYSTEM 0
#define RTC_HAS_FILESYSTEM_TS 1
#else
#define RTC_HAS_FILESYSTEM 0
#define RTC_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include)
// Probe the headers directly; each also needs a new enough language mode.
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define RTC_HAS_FILESYSTEM 1
#else
#define RTC_HAS_FILESYSTEM 0
#endif
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
#define RTC_HAS_FILESYSTEM_TS 1
#else
#define RTC_HAS_FILESYSTEM_TS 0
#endif
#else
// No way to detect a filesystem header.
#define RTC_HAS_FILESYSTEM 0
#define RTC_HAS_FILESYSTEM_TS 0
#endif
// clang-format on

#if RTC_HAS_FILESYSTEM
#include <filesystem>
#elif RTC_HAS_FILESYSTEM_TS
#include <experimental/filesystem>
#else
#error "No filesystem include available"
#endif

namespace rtc {

// rtc::fs names whichever implementation was selected above.
#if RTC_HAS_FILESYSTEM
namespace fs = ::std::filesystem;
#elif RTC_HAS_FILESYSTEM_TS
namespace fs = ::std::experimental::filesystem;
#endif

} // namespace rtc

#endif // GUARD_TEST_HOST_RTC_FILESYSTEM_HPP
codegen/test/rtc/include/rtc/tmp_dir.hpp
View file @
6e3c786e
...
@@ -2,13 +2,13 @@
...
@@ -2,13 +2,13 @@
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR
#include <string>
#include <string>
#include <c
k
/filesystem.hpp>
#include <
rt
c/filesystem.hpp>
namespace
rtc
{
namespace
rtc
{
struct
tmp_dir
struct
tmp_dir
{
{
CK
::
fs
::
path
path
;
fs
::
path
path
;
tmp_dir
(
const
std
::
string
&
prefix
=
""
);
tmp_dir
(
const
std
::
string
&
prefix
=
""
);
void
execute
(
const
std
::
string
&
cmd
)
const
;
void
execute
(
const
std
::
string
&
cmd
)
const
;
...
...
codegen/test/rtc/src/compile_kernel.cpp
View file @
6e3c786e
#include
"
rtc/hip.hpp
"
#include
<
rtc/hip.hpp
>
#include <rtc/compile_kernel.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/tmp_dir.hpp>
#include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <stdexcept>
...
@@ -70,9 +70,9 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
...
@@ -70,9 +70,9 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
for
(
const
auto
&
src
:
srcs
)
for
(
const
auto
&
src
:
srcs
)
{
{
CK
::
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
fs
::
path
full_path
=
td
.
path
/
src
.
path
;
CK
::
fs
::
path
parent_path
=
full_path
.
parent_path
();
fs
::
path
parent_path
=
full_path
.
parent_path
();
CK
::
fs
::
create_directories
(
parent_path
);
fs
::
create_directories
(
parent_path
);
write_string
(
full_path
.
string
(),
src
.
content
);
write_string
(
full_path
.
string
(),
src
.
content
);
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
if
(
src
.
path
.
extension
().
string
()
==
".cpp"
)
{
{
...
@@ -86,7 +86,7 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
...
@@ -86,7 +86,7 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
td
.
execute
(
compiler
()
+
options
.
flags
);
td
.
execute
(
compiler
()
+
options
.
flags
);
auto
out_path
=
td
.
path
/
out
;
auto
out_path
=
td
.
path
/
out
;
if
(
not
CK
::
fs
::
exists
(
out_path
))
if
(
not
fs
::
exists
(
out_path
))
throw
std
::
runtime_error
(
"Output file missing: "
+
out
);
throw
std
::
runtime_error
(
"Output file missing: "
+
out
);
auto
obj
=
read_buffer
(
out_path
.
string
());
auto
obj
=
read_buffer
(
out_path
.
string
());
...
...
codegen/test/rtc/src/tmp_dir.cpp
View file @
6e3c786e
...
@@ -31,10 +31,10 @@ std::string unique_string(const std::string& prefix)
...
@@ -31,10 +31,10 @@ std::string unique_string(const std::string& prefix)
}
}
tmp_dir
::
tmp_dir
(
const
std
::
string
&
prefix
)
tmp_dir
::
tmp_dir
(
const
std
::
string
&
prefix
)
:
path
(
CK
::
fs
::
temp_directory_path
()
/
:
path
(
fs
::
temp_directory_path
()
/
unique_string
(
prefix
.
empty
()
?
"ck-rtc"
:
"ck-rtc-"
+
prefix
))
unique_string
(
prefix
.
empty
()
?
"ck-rtc"
:
"ck-rtc-"
+
prefix
))
{
{
CK
::
fs
::
create_directories
(
this
->
path
);
fs
::
create_directories
(
this
->
path
);
}
}
void
tmp_dir
::
execute
(
const
std
::
string
&
cmd
)
const
void
tmp_dir
::
execute
(
const
std
::
string
&
cmd
)
const
...
@@ -43,6 +43,6 @@ void tmp_dir::execute(const std::string& cmd) const
...
@@ -43,6 +43,6 @@ void tmp_dir::execute(const std::string& cmd) const
std
::
system
(
s
.
c_str
());
std
::
system
(
s
.
c_str
());
}
}
tmp_dir
::~
tmp_dir
()
{
CK
::
fs
::
remove_all
(
this
->
path
);
}
tmp_dir
::~
tmp_dir
()
{
fs
::
remove_all
(
this
->
path
);
}
}
// namespace rtc
}
// namespace rtc
docs/reference/API_Reference_Guide.rst
View file @
6e3c786e
...
@@ -12,12 +12,6 @@ API reference guide
...
@@ -12,12 +12,6 @@ API reference guide
This document contains details of the APIs for the Composable Kernel (CK) library and introduces
This document contains details of the APIs for the Composable Kernel (CK) library and introduces
some of the key design principles that are used to write new classes that extend CK functionality.
some of the key design principles that are used to write new classes that extend CK functionality.
=================
Using CK API
=================
This section describes how to use the CK library API.
=================
=================
CK Datatypes
CK Datatypes
=================
=================
...
...
docs/sphinx/requirements.in
View file @
6e3c786e
rocm-docs-core==1.
8.2
rocm-docs-core==1.
11.0
sphinxcontrib-bibtex==2.6.3
sphinxcontrib-bibtex==2.6.3
docs/sphinx/requirements.txt
View file @
6e3c786e
...
@@ -103,7 +103,7 @@ requests==2.32.3
...
@@ -103,7 +103,7 @@ requests==2.32.3
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core==1.
8.2
rocm-docs-core==1.
11.0
# via -r requirements.in
# via -r requirements.in
six==1.16.0
six==1.16.0
# via pybtex
# via pybtex
...
...
example/01_gemm/CMakeLists.txt
View file @
6e3c786e
...
@@ -79,9 +79,16 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
...
@@ -79,9 +79,16 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
add_example_executable
(
example_gemm_xdl_fp8_streamk_v3 gemm_xdl_fp8_streamk_v3.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_streamk_v3
)
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
add_custom_target
(
example_gemm_wmma
)
add_custom_target
(
example_gemm_wmma
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
add_example_executable
(
example_gemm_wmma_bf16 gemm_wmma_bf16.cpp
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_bf16
)
add_example_executable
(
example_gemm_wmma_int8 gemm_wmma_int8.cpp
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_int8
)
example/01_gemm/common.hpp
View file @
6e3c786e
...
@@ -29,9 +29,9 @@ struct ProblemSize final
...
@@ -29,9 +29,9 @@ struct ProblemSize final
ck
::
index_t
N
=
4096
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
0
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideB
=
0
;
ck
::
index_t
StrideB
=
-
1
;
ck
::
index_t
StrideC
=
0
;
ck
::
index_t
StrideC
=
-
1
;
};
};
struct
ProblemSizeStreamK
final
struct
ProblemSizeStreamK
final
...
@@ -40,11 +40,11 @@ struct ProblemSizeStreamK final
...
@@ -40,11 +40,11 @@ struct ProblemSizeStreamK final
ck
::
index_t
N
=
4096
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
0
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideB
=
0
;
ck
::
index_t
StrideB
=
-
1
;
ck
::
index_t
StrideC
=
0
;
ck
::
index_t
StrideC
=
-
1
;
ck
::
index_t
NumSKBlocks
=
-
1
;
ck
::
index_t
NumSKBlocks
=
-
1
;
// number of stream-k blocks
};
};
struct
ProblemSizeStreamK_universal
final
struct
ProblemSizeStreamK_universal
final
{
{
...
@@ -52,9 +52,9 @@ struct ProblemSizeStreamK_universal final
...
@@ -52,9 +52,9 @@ struct ProblemSizeStreamK_universal final
ck
::
index_t
N
=
4096
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
0
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideB
=
0
;
ck
::
index_t
StrideB
=
-
1
;
ck
::
index_t
StrideC
=
0
;
ck
::
index_t
StrideC
=
-
1
;
ck
::
index_t
Grid_size
=
-
1
;
// defaults to max occupancy
ck
::
index_t
Grid_size
=
-
1
;
// defaults to max occupancy
ck
::
index_t
Streamk_sel
=
1
;
// defaults to 1-tile SK
ck
::
index_t
Streamk_sel
=
1
;
// defaults to 1-tile SK
...
@@ -66,18 +66,19 @@ struct ProblemSizeSplitK final
...
@@ -66,18 +66,19 @@ struct ProblemSizeSplitK final
ck
::
index_t
N
=
4096
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
0
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideB
=
0
;
ck
::
index_t
StrideB
=
-
1
;
ck
::
index_t
StrideC
=
0
;
ck
::
index_t
StrideC
=
-
1
;
ck
::
index_t
KBatch
=
1
;
ck
::
index_t
KBatch
=
1
;
};
};
struct
ExecutionConfig
final
struct
ExecutionConfig
final
{
{
bool
do_verification
=
true
;
// 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
int
init_method
=
2
;
int
do_verification
=
1
;
bool
time_kernel
=
false
;
int
init_method
=
2
;
bool
time_kernel
=
false
;
};
};
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
...
@@ -126,7 +127,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
...
@@ -126,7 +127,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU
, 2=GPU, 3=CPU
and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
@@ -176,7 +177,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
...
@@ -176,7 +177,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
else
else
{
{
std
::
cerr
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
<<
"arg1: verification (0=no, 1=CPU
, 2=GPU, 3=CPU
and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
...
@@ -225,7 +226,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
...
@@ -225,7 +226,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU
, 2=GPU, 3=CPU
and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
@@ -275,7 +276,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
...
@@ -275,7 +276,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU and GPU)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=CPU
, 2=GPU, 3=CPU
and GPU)"
<<
std
::
endl
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
std
::
endl
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
...
...
example/01_gemm/gemm_wmma_bf16.cpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
// Element types for this example: bf16 for A, B and C; float for AccDataType
// and CShuffleDataType.
using ADataType        = ck::bhalf_t;
using BDataType        = ck::bhalf_t;
using CDataType        = ck::bhalf_t;
using AccDataType      = float;
using CShuffleDataType = float;

// Matrix layouts. Row / Col are not declared in this file section
// (presumably provided through common.hpp — confirm).
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

// Element-wise operators applied to A, B and C; all three use PassThrough.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// Note: despite its name, GemmDefault selects the MNKPadding specialization.
static constexpr auto GemmDefault =
    ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
// Device-side GEMM instance exercised by this example.
// Labels on the numeric arguments are carried over from the original listing;
// unlabeled arguments are grouped below without asserting their meaning.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<
    ALayout, BLayout, CLayout,
    ADataType, BDataType, CDataType,
    AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CElementOp,
    GemmDefault,
    1,   // Prefetch stage
    128, // BlockSize
    64,  // MPerBlock
    128, // NPerBlock
    64,  // KPerBlock
    2,   // K1
    16,  // MPerWmma
    16,  // NPerWmma
    2,   // M-Repeat
    4,   // N-Repeat
    // NOTE(review): the earlier notes here read "M-PerWmma / M-Repeat = M-Wave" and
    // "N-PerWmma / N-Repeat = N-Wave". With the values above those give 8 and 4;
    // presumably the wave counts are MPerBlock / (MPerWmma * MRepeat) = 2 and
    // NPerBlock / (NPerWmma * NRepeat) = 2 — confirm against device_gemm_wmma.hpp.
    // First unlabeled group (presumably the A-side block-transfer settings — confirm).
    S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true,
    // Second unlabeled group, same values (presumably the B-side settings — confirm).
    S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true,
    1, // C shuffle (M Repeat) Per store
    1, // C shuffle (N Repeat) Per store
    S<1, 32, 1, 4>,
    8>;
// clang-format on
// Host-side reference GEMM type, parameterised with the same element types and
// element-wise operators as the device instance.
using ReferenceGemmInstance =
    ck::tensor_operation::host::ReferenceGemm<ADataType,
                                              BDataType,
                                              CDataType,
                                              AccDataType,
                                              AElementOp,
                                              BElementOp,
                                              CElementOp>;

// Device-side reference GEMM type; unlike the host variant it also takes the
// three layout types.
using ReferenceGemmInstanceGPU =
    ck::tensor_operation::device::ReferenceGemm<ALayout,
                                                BLayout,
                                                CLayout,
                                                ADataType,
                                                BDataType,
                                                CDataType,
                                                AccDataType,
                                                AElementOp,
                                                BElementOp,
                                                CElementOp>;
#include "run_gemm_example.inc"
// Entry point: hands the command line to run_gemm_example() and returns the
// logical negation of its result as the exit status (truthy result -> 0,
// falsy result -> 1).
int main(int argc, char* argv[])
{
    const bool failed = !run_gemm_example(argc, argv);
    return failed ? 1 : 0;
}
example/01_gemm/gemm_wmma_int8.cpp
0 → 100644
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
// Element types for this example: int8 for A, B and C; int32 for AccDataType
// and CShuffleDataType.
using ADataType        = int8_t;
using BDataType        = int8_t;
using CDataType        = int8_t;
using AccDataType      = int32_t;
using CShuffleDataType = int32_t;

// Matrix layouts. Row / Col are not declared in this file section
// (presumably provided through common.hpp — confirm).
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

// Element-wise operators applied to A, B and C; all three use PassThrough.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// Note: despite its name, GemmDefault selects the MNKPadding specialization.
static constexpr auto GemmDefault =
    ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
// Device-side GEMM instance exercised by this example.
// Labels on the numeric arguments are carried over from the original listing;
// unlabeled arguments are grouped below without asserting their meaning.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<
    ALayout, BLayout, CLayout,
    ADataType, BDataType, CDataType,
    AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CElementOp,
    GemmDefault,
    1,   // Prefetch stage
    128, // BlockSize
    64,  // MPerBlock
    128, // NPerBlock
    64,  // KPerBlock
    2,   // K1
    16,  // MPerWmma
    16,  // NPerWmma
    2,   // M-Repeat
    4,   // N-Repeat
    // NOTE(review): the earlier notes here read "M-PerWmma / M-Repeat = M-Wave" and
    // "N-PerWmma / N-Repeat = N-Wave". With the values above those give 8 and 4;
    // presumably the wave counts are MPerBlock / (MPerWmma * MRepeat) = 2 and
    // NPerBlock / (NPerWmma * NRepeat) = 2 — confirm against device_gemm_wmma.hpp.
    // First unlabeled group (presumably the A-side block-transfer settings — confirm).
    S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true,
    // Second unlabeled group, same values (presumably the B-side settings — confirm).
    S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true,
    1, // C shuffle (M Repeat) Per store
    1, // C shuffle (N Repeat) Per store
    S<1, 32, 1, 4>,
    8>;
// clang-format on
// Host-side reference GEMM type, parameterised with the same element types and
// element-wise operators as the device instance.
using ReferenceGemmInstance =
    ck::tensor_operation::host::ReferenceGemm<ADataType,
                                              BDataType,
                                              CDataType,
                                              AccDataType,
                                              AElementOp,
                                              BElementOp,
                                              CElementOp>;

// Device-side reference GEMM type; unlike the host variant it also takes the
// three layout types.
using ReferenceGemmInstanceGPU =
    ck::tensor_operation::device::ReferenceGemm<ALayout,
                                                BLayout,
                                                CLayout,
                                                ADataType,
                                                BDataType,
                                                CDataType,
                                                AccDataType,
                                                AElementOp,
                                                BElementOp,
                                                CElementOp>;
#include "run_gemm_example.inc"
// Entry point: hands the command line to run_gemm_example() and returns the
// logical negation of its result as the exit status (truthy result -> 0,
// falsy result -> 1).
int main(int argc, char* argv[])
{
    const bool failed = !run_gemm_example(argc, argv);
    return failed ? 1 : 0;
}
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp
View file @
6e3c786e
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
using
ADataType
=
ck
::
half_t
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
ck
::
half_
t
;
using
CShuffleDataType
=
floa
t
;
using
CDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
...
@@ -43,6 +43,17 @@ using DeviceGemmV2_Streamk_Instance =
...
@@ -43,6 +43,17 @@ using DeviceGemmV2_Streamk_Instance =
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
using
ReferenceGemmInstanceGPU
=
ck
::
tensor_operation
::
device
::
ReferenceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
#include "run_gemm_example_streamk_v2.inc"
#include "run_gemm_example_streamk_v2.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_universal_streamk_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_universal_streamk_example
(
argc
,
argv
);
}
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
0 → 100755
View file @
6e3c786e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
// Element types for this example: fp8 for A and B; half for C and for
// CShuffleDataType; float for AccDataType.
using ADataType        = ck::f8_t;
using BDataType        = ck::f8_t;
using CDataType        = ck::half_t;
using AccDataType      = float;
using CShuffleDataType = ck::half_t;

// Matrix layouts. Row / Col are not declared in this file section
// (presumably provided through common.hpp — confirm).
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

// Element-wise operators applied to A, B and C; all three use PassThrough.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// GEMM specialization handed to the device instance: Default.
static constexpr auto GemmDefault =
    ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
// Stream-K XDL GEMM device instance exercised by this example.
// The numeric arguments carry no labels in the original listing; they are kept
// in their original order and grouped only for readability.
// NOTE(review): confirm each position against DeviceGemm_Xdl_CShuffle_Streamk_V3
// before changing any value.
using DeviceGemmV2_Streamk_Instance =
    ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
        ALayout, BLayout, CLayout,
        ADataType, BDataType, CDataType,
        AccDataType, CShuffleDataType,
        PassThrough, PassThrough, PassThrough,
        GemmDefault,
        256,
        128, 256, 128,
        16, 16,
        16, 16,
        4, 8,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
        1, 2,
        S<1, 32, 1, 8>,
        8,
        ck::BlockGemmPipelineScheduler::Intrawave,
        ck::BlockGemmPipelineVersion::v3,
        ck::f8_t>;
// clang-format on
// Host-side reference GEMM type, parameterised with this example's element
// types and element-wise operator aliases.
using ReferenceGemmInstance =
    ck::tensor_operation::host::ReferenceGemm<ADataType,
                                              BDataType,
                                              CDataType,
                                              AccDataType,
                                              AElementOp,
                                              BElementOp,
                                              CElementOp>;

// Device-side reference GEMM type; unlike the host variant it also takes the
// three layout types.
using ReferenceGemmInstanceGPU =
    ck::tensor_operation::device::ReferenceGemm<ALayout,
                                                BLayout,
                                                CLayout,
                                                ADataType,
                                                BDataType,
                                                CDataType,
                                                AccDataType,
                                                AElementOp,
                                                BElementOp,
                                                CElementOp>;
#include "run_gemm_example_streamk_v2.inc"
// Entry point: hands the command line to run_gemm_universal_streamk_example()
// and returns the logical negation of its result as the exit status
// (truthy result -> 0, falsy result -> 1).
int main(int argc, char* argv[])
{
    const bool failed = !run_gemm_universal_streamk_example(argc, argv);
    return failed ? 1 : 0;
}
example/01_gemm/run_gemm_example.inc
View file @
6e3c786e
...
@@ -34,21 +34,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -34,21 +34,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
};
};
auto
f_get_default_stride
=
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size
_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
ck
::
index
_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
if
(
stride
==
-
1
)
{
{
// give a chance if stride is
zero
, return a default packed stride
// give a chance if stride is
-1
, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
col
;
return
static_cast
<
std
::
size_t
>
(
col
)
;
}
}
else
else
{
{
return
row
;
return
static_cast
<
std
::
size_t
>
(
row
)
;
}
}
}
}
else
else
return
stride
;
return
static_cast
<
std
::
size_t
>
(
stride
)
;
};
};
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
...
@@ -61,8 +61,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -61,8 +61,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
switch
(
config
.
init_method
)
switch
(
config
.
init_method
)
{
{
case
0
:
case
0
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
static_cas
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
ADataType
>
{
ck
::
type_conver
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
static_cas
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
ck
::
type_conver
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
break
;
break
;
case
1
:
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
...
@@ -248,7 +248,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -248,7 +248,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
// CPU verification
// CPU verification
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
@@ -271,13 +271,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -271,13 +271,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
&=
!
ck
::
utils
::
check_err
(
c_m_n_device_result
,
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
c_m_n_host_result
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
get_atol
<
CDataType
>
());
#endif
#endif
}
if
((
config
.
do_verification
==
2
)
||
(
config
.
do_verification
==
3
))
{
// GPU verification
// GPU verification
auto
ref_gemm_gpu
=
ReferenceGemmInstanceGPU
{};
auto
ref_gemm_gpu
=
ReferenceGemmInstanceGPU
{};
auto
ref_invoker_gpu
=
ref_gemm_gpu
.
MakeInvoker
();
auto
ref_invoker_gpu
=
ref_gemm_gpu
.
MakeInvoker
();
...
@@ -299,14 +302,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -299,14 +302,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
c_m_n_device_ref_buf
.
FromDevice
(
c_m_n_device_ref_result
.
mData
.
data
());
c_m_n_device_ref_buf
.
FromDevice
(
c_m_n_device_ref_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
&=
!
ck
::
utils
::
check_err
(
c_m_n_device_result
,
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_device_ref_result
,
c_m_n_device_ref_result
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
get_atol
<
CDataType
>
());
}
}
return
!
pass
;
return
pass
==
true
;
}
}
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
...
...
example/01_gemm/run_gemm_example_streamk_v2.inc
100644 → 100755
View file @
6e3c786e
...
@@ -94,6 +94,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -94,6 +94,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_ref_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
...
@@ -114,6 +115,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -114,6 +115,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_ref_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_ref_result
.
mDesc
.
GetElementSpaceSize
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
...
@@ -158,8 +161,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -158,8 +161,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
return
true
;
return
true
;
}
}
std
::
size_t
workspace_size
=
gemm
.
GetWorkSpaceSize
(
&
argument
);
if
(
workspace_size
!=
0
)
{
workspace
.
Realloc
(
workspace_size
);
gemm
.
SetWorkSpacePointer
(
&
argument
,
workspace
.
GetDeviceBuffer
());
}
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
@@ -189,6 +199,36 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -189,6 +199,36 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#endif
#endif
}
}
if
((
config
.
do_verification
==
2
)
||
(
config
.
do_verification
==
3
))
{
// GPU verification
auto
ref_gemm_gpu
=
ReferenceGemmInstanceGPU
{};
auto
ref_invoker_gpu
=
ref_gemm_gpu
.
MakeInvoker
();
auto
ref_argument_gpu
=
ref_gemm_gpu
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_ref_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
a_element_op
,
b_element_op
,
c_element_op
);
std
::
cout
<<
"Running verification on GPU."
<<
std
::
endl
;
ref_invoker_gpu
.
Run
(
ref_argument_gpu
,
StreamConfig
{});
c_m_n_device_ref_buf
.
FromDevice
(
c_m_n_device_ref_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_device_ref_result
,
"Error: Incorrect results!"
,
get_rtol
<
CDataType
>
(),
get_atol
<
CDataType
>
());
}
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
{
{
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
...
...
example/01_gemm/run_gemm_example_v2.inc
View file @
6e3c786e
...
@@ -33,21 +33,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -33,21 +33,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
};
};
auto
f_get_default_stride
=
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size
_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
ck
::
index
_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
if
(
stride
==
-
1
)
{
{
// give a chance if stride is
zero
, return a default packed stride
// give a chance if stride is
-1
, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
col
;
return
static_cast
<
std
::
size_t
>
(
col
)
;
}
}
else
else
{
{
return
row
;
return
static_cast
<
std
::
size_t
>
(
row
)
;
}
}
}
}
else
else
return
stride
;
return
static_cast
<
std
::
size_t
>
(
stride
)
;
};
};
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
...
@@ -146,7 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -146,7 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
bool
pass
=
true
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
(
config
.
do_verification
==
1
)
||
(
config
.
do_verification
==
3
)
)
{
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
@@ -179,7 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -179,7 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
{
{
ave_time
=
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
0
,
5
,
10
,
true
,
4
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
0
,
5
0
,
10
0
,
true
,
4
});
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
...
example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp
View file @
6e3c786e
...
@@ -80,7 +80,7 @@ using RLayout = typename LayoutSettingSelector<NDimSpatial>::RLayout;
...
@@ -80,7 +80,7 @@ using RLayout = typename LayoutSettingSelector<NDimSpatial>::RLayout;
struct
ExecutionConfig
final
struct
ExecutionConfig
final
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
2
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
};
};
...
...
example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc
View file @
6e3c786e
...
@@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
...
@@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
Tensor
<
EDataType
>
conv_output_device
(
conv_output_g_n_k_wos_desc
);
Tensor
<
EDataType
>
conv_output_device
(
conv_output_g_n_k_wos_desc
);
Tensor
<
R0DataType
>
r0_device
(
r0_desc
);
Tensor
<
R0DataType
>
r0_device
(
r0_desc
);
std
::
cout
<<
"input: "
<<
conv_input
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"weight: "
<<
conv_weight
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"output: "
<<
conv_output_device
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"reduction: "
<<
r0_device
.
mDesc
<<
std
::
endl
<<
std
::
endl
;
switch
(
config
.
init_method
)
switch
(
config
.
init_method
)
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
8
,
7
}(
conv_input
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
8
,
7
}(
conv_input
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
8
,
7
}(
conv_weight
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
1
,
1
}(
conv_weight
);
break
;
case
2
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
8
,
7
}(
conv_input
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
1
,
1
}(
conv_weight
);
break
;
break
;
default
:
default
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
5
,
5
}(
conv_input
);
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
8
,
7
}(
conv_input
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
5
,
5
}(
conv_weight
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
1
,
1
}(
conv_weight
);
}
}
DeviceMem
conv_input_device_buf
(
sizeof
(
ADataType
)
*
conv_input
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
conv_input_device_buf
(
sizeof
(
ADataType
)
*
conv_input
.
mDesc
.
GetElementSpaceSize
());
...
@@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
...
@@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
return
false
;
return
false
;
}
}
// XXX: DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle will not initialize r0.
r0_device_buf
.
SetValue
(
ck
::
NumericLimits
<
R0DataType
>::
Lowest
());
const
float
avg_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
const
float
avg_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
const
std
::
size_t
flop
=
problem_size
.
GetFlops
();
if
(
config
.
time_kernel
)
const
std
::
size_t
num_btype
=
problem_size
.
GetByte
<
ADataType
,
BDataType
,
EDataType
>
();
{
const
std
::
size_t
flop
=
problem_size
.
GetFlops
();
const
std
::
size_t
num_btype
=
problem_size
.
GetByte
<
ADataType
,
BDataType
,
EDataType
>
();
const
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
const
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
const
float
gb_per_sec
=
num_btype
/
1.E6
/
avg_time
;
const
float
gb_per_sec
=
num_btype
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
conv
.
GetTypeString
()
<<
std
::
endl
;
<<
" GB/s, "
<<
conv
.
GetTypeString
()
<<
std
::
endl
;
}
else
{
std
::
cout
<<
"FINISHED: "
<<
conv
.
GetTypeString
()
<<
std
::
endl
;
}
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
...
@@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
...
@@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
BElementOp
{},
BElementOp
{},
PassThrough
{});
PassThrough
{});
std
::
cout
<<
"
\n
Running verification on CPU."
<<
std
::
endl
;
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
Tensor
<
R0DataType
>
r0_host
(
r0_device
.
mDesc
);
Tensor
<
R0DataType
>
r0_host
(
r0_device
.
mDesc
);
...
@@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
...
@@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
conv_output_device_buf
.
FromDevice
(
conv_output_device
.
mData
.
data
());
conv_output_device_buf
.
FromDevice
(
conv_output_device
.
mData
.
data
());
r0_device_buf
.
FromDevice
(
r0_device
.
mData
.
data
());
r0_device_buf
.
FromDevice
(
r0_device
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
conv_output_device
,
auto
pass
=
ck
::
utils
::
check_err
(
conv_output_device
,
conv_output_host
,
conv_output_host
,
"Error: incorrect results! (Matrix E)"
,
"Error: incorrect results! (Matrix E)"
,
1
e
-
5
f
,
1
e
-
3
f
,
1
e
-
4
f
)
&&
1
e
-
3
f
);
ck
::
utils
::
check_err
(
pass
=
r0_device
,
r0_host
,
"Error: incorrect results! (Matrix R0)"
,
1
e
-
5
f
,
1
e
-
4
f
);
pass
&&
ck
::
utils
::
check_err
(
r0_device
,
r0_host
,
"Error: incorrect results! (Matrix R0)"
,
1
e
-
3
f
,
1
e
-
3
f
);
if
(
pass
)
std
::
cout
<<
"Verification on CPU: PASS"
<<
std
::
endl
;
return
pass
;
}
}
return
true
;
return
true
;
...
...
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
View file @
6e3c786e
...
@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
A
DataType
>
{
0.0
,
1.0
});
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
D
DataType
>
{
0.0
,
1.0
});
}
}
break
;
break
;
default:
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
DDataType
,
0
>
{});
}
}
}
}
}
}
...
@@ -246,7 +246,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -246,7 +246,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
// do GEMM
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
gemm
.
SetKBatchSize
(
argument
,
config
.
k_batch
);
gemm
.
SetKBatchSize
(
&
argument
,
config
.
k_batch
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
...
@@ -257,7 +257,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -257,7 +257,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_workspace_dev
.
GetDeviceBuffer
());
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_workspace_dev
.
GetDeviceBuffer
());
DeviceMem
gemm_arg_dev_mem
(
gemm
.
GetDeviceKernelArgSize
(
&
argument
));
DeviceMem
gemm_arg_dev_mem
(
gemm
.
GetDeviceKernelArgSize
(
&
argument
));
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_arg_dev_mem
.
GetDeviceBuffer
());
gemm
.
SetDeviceKernelArgs
(
&
argument
,
gemm_arg_dev_mem
.
GetDeviceBuffer
());
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
1
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
1
});
...
...
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
View file @
6e3c786e
...
@@ -91,7 +91,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -91,7 +91,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
{
{
auto
group_count
=
problem_size
.
group_count
;
auto
group_count
=
problem_size
.
group_count
;
using
KernelArguments
=
ck
::
tensor_operation
::
device
::
GroupedGemm
TileLoop
KernelArgument
s
<
NumDs
>
;
using
KernelArguments
=
ck
::
tensor_operation
::
device
::
GroupedGemmKernelArgument
<
NumDs
>
;
using
GemmDesc
=
ck
::
tensor_operation
::
device
::
GemmDesc
;
using
GemmDesc
=
ck
::
tensor_operation
::
device
::
GemmDesc
;
// GEMM shape
// GEMM shape
...
@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
for
(
int
j
=
0
;
j
<
NumDs
;
++
j
)
for
(
int
j
=
0
;
j
<
NumDs
;
++
j
)
{
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
A
DataType
>
{
0.0
,
1.0
});
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
D
DataType
>
{
0.0
,
1.0
});
}
}
break
;
break
;
default:
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
for
(
int
j
=
0
;
j
<
NumDs
;
++
j
)
for
(
int
j
=
0
;
j
<
NumDs
;
++
j
)
{
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
DDataType
,
0
>
{});
}
}
}
}
}
}
...
...
Prev
1
2
3
4
5
6
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment