Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6bbb94f4
"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "93cec4335fed91f07683e4d69cb7980ca050e64d"
Unverified
Commit
6bbb94f4
authored
Jun 26, 2024
by
Rostyslav Geyyer
Committed by
GitHub
Jun 26, 2024
Browse files
Merge branch 'develop' into lwpck-1815
parents
4c850c90
a32b1bc6
Changes
107
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
400 additions
and
63 deletions
+400
-63
CMakeLists.txt
CMakeLists.txt
+1
-1
client_example/24_grouped_conv_activation/CMakeLists.txt
client_example/24_grouped_conv_activation/CMakeLists.txt
+6
-1
client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp
...grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp
+50
-0
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+14
-23
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+2
-1
codegen/CMakeLists.txt
codegen/CMakeLists.txt
+21
-7
codegen/driver/main.cpp
codegen/driver/main.cpp
+38
-4
codegen/include/ck/host/device_gemm_multiple_d.hpp
codegen/include/ck/host/device_gemm_multiple_d.hpp
+1
-1
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
+15
-2
codegen/include/ck/host/device_gemm_multiple_d/problem.hpp
codegen/include/ck/host/device_gemm_multiple_d/problem.hpp
+12
-5
codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp
...k/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp
+60
-0
codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp
...t/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp
+56
-0
codegen/include/ck/host/headers.hpp
codegen/include/ck/host/headers.hpp
+0
-1
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+1
-1
codegen/include/ck/host/stringutils.hpp
codegen/include/ck/host/stringutils.hpp
+1
-1
codegen/include/ck/host/types.hpp
codegen/include/ck/host/types.hpp
+13
-5
codegen/include/ck/host/utils.hpp
codegen/include/ck/host/utils.hpp
+3
-2
codegen/src/device_gemm_multiple_d.cpp
codegen/src/device_gemm_multiple_d.cpp
+10
-5
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
...gen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
+54
-3
codegen/src/device_grouped_conv_fwd_multiple_abd.cpp
codegen/src/device_grouped_conv_fwd_multiple_abd.cpp
+42
-0
No files found.
CMakeLists.txt
View file @
6bbb94f4
...
@@ -112,7 +112,7 @@ message("checking which targets are supported")
...
@@ -112,7 +112,7 @@ message("checking which targets are supported")
#Setting GPU_TARGETS on command line will override this list
#Setting GPU_TARGETS on command line will override this list
if
(
NOT PROFILER_ONLY
)
if
(
NOT PROFILER_ONLY
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
TARGETS
"
gfx900;gfx906;
gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
)
TARGETS
"gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
)
else
()
else
()
add_definitions
(
-DPROFILER_ONLY
)
add_definitions
(
-DPROFILER_ONLY
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
...
...
client_example/24_grouped_conv_activation/CMakeLists.txt
View file @
6bbb94f4
...
@@ -40,9 +40,14 @@ add_executable(client_conv3d_fwd_convinvscale_fp8
...
@@ -40,9 +40,14 @@ add_executable(client_conv3d_fwd_convinvscale_fp8
grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp
)
grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations
)
# Fwd convscale
# Fwd convscale
add_executable
(
client_conv3d_fwd_convscale_fp8
add_executable
(
client_conv3d_fwd_convscale_fp8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp
)
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations
)
add_executable
(
client_conv3d_fwd_convscale_bf8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_bf8 PRIVATE composable_kernel::device_conv_operations
)
add_executable
(
client_conv3d_fwd_convscale_fp8_bf8
add_executable
(
client_conv3d_fwd_convscale_fp8_bf8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp
)
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations
)
...
...
client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp
0 → 100644
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using
InDataType
=
ck
::
bf8_t
;
using
WeiDataType
=
ck
::
bf8_t
;
using
CShuffleDataType
=
float
;
using
OutDataType
=
ck
::
f8_t
;
using
AComputeDataType
=
InDataType
;
using
BComputeDataType
=
AComputeDataType
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
GKZYXC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
3
;
static
constexpr
ck
::
index_t
G
=
1
;
static
constexpr
ck
::
index_t
N
=
64
;
static
constexpr
ck
::
index_t
K
=
128
;
static
constexpr
ck
::
index_t
C
=
64
;
static
constexpr
ck
::
index_t
Z
=
3
;
static
constexpr
ck
::
index_t
Y
=
3
;
static
constexpr
ck
::
index_t
X
=
3
;
static
constexpr
ck
::
index_t
Di
=
28
;
static
constexpr
ck
::
index_t
Hi
=
28
;
static
constexpr
ck
::
index_t
Wi
=
3
;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
int
main
()
{
return
run_grouped_conv_fwd_convscale
<
NumDimSpatial
,
InDataType
,
WeiDataType
,
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
,
3
,
AComputeDataType
,
BComputeDataType
>
(
{
N
,
Di
,
Hi
,
Wi
,
G
,
C
},
{
G
,
K
,
Z
,
Y
,
X
,
C
},
{
N
,
Do
,
Ho
,
Wo
,
G
,
K
})
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
}
client_example/CMakeLists.txt
View file @
6bbb94f4
...
@@ -6,46 +6,36 @@ if (DTYPES)
...
@@ -6,46 +6,36 @@ if (DTYPES)
add_definitions
(
-DDTYPES
)
add_definitions
(
-DDTYPES
)
if
(
DTYPES MATCHES
"int8"
)
if
(
DTYPES MATCHES
"int8"
)
add_definitions
(
-DCK_ENABLE_INT8
)
add_definitions
(
-DCK_ENABLE_INT8
)
if
(
NOT DEFINED
${
CK_ENABLE_INT8
}
)
set
(
CK_ENABLE_INT8
"ON"
)
set
(
CK_ENABLE_INT8
"ON"
)
endif
()
endif
()
endif
()
if
(
DTYPES MATCHES
"fp8"
)
if
(
DTYPES MATCHES
"fp8"
)
add_definitions
(
-DCK_ENABLE_FP8
)
add_definitions
(
-DCK_ENABLE_FP8
)
if
(
NOT DEFINED
${
CK_ENABLE_FP8
}
)
set
(
CK_ENABLE_FP8
"ON"
)
set
(
CK_ENABLE_FP8
"ON"
)
endif
()
endif
()
if
(
DTYPES MATCHES
"bf8"
)
add_definitions
(
-DCK_ENABLE_BF8
)
set
(
CK_ENABLE_BF8
"ON"
)
endif
()
endif
()
if
(
DTYPES MATCHES
"fp16"
)
if
(
DTYPES MATCHES
"fp16"
)
add_definitions
(
-DCK_ENABLE_FP16
)
add_definitions
(
-DCK_ENABLE_FP16
)
if
(
NOT DEFINED
${
CK_ENABLE_FP16
}
)
set
(
CK_ENABLE_FP16
"ON"
)
set
(
CK_ENABLE_FP16
"ON"
)
endif
()
endif
()
endif
()
if
(
DTYPES MATCHES
"fp32"
)
if
(
DTYPES MATCHES
"fp32"
)
add_definitions
(
-DCK_ENABLE_FP32
)
add_definitions
(
-DCK_ENABLE_FP32
)
if
(
NOT DEFINED
${
CK_ENABLE_FP32
}
)
set
(
CK_ENABLE_FP32
"ON"
)
set
(
CK_ENABLE_FP32
"ON"
)
endif
()
endif
()
endif
()
if
(
DTYPES MATCHES
"fp64"
)
if
(
DTYPES MATCHES
"fp64"
)
add_definitions
(
-DCK_ENABLE_FP64
)
add_definitions
(
-DCK_ENABLE_FP64
)
if
(
NOT DEFINED
${
CK_ENABLE_FP64
}
)
set
(
CK_ENABLE_FP64
"ON"
)
set
(
CK_ENABLE_FP64
"ON"
)
endif
()
endif
()
endif
()
if
(
DTYPES MATCHES
"bf16"
)
if
(
DTYPES MATCHES
"bf16"
)
add_definitions
(
-DCK_ENABLE_BF16
)
add_definitions
(
-DCK_ENABLE_BF16
)
if
(
NOT DEFINED
${
CK_ENABLE_BF16
}
)
set
(
CK_ENABLE_BF16
"ON"
)
set
(
CK_ENABLE_BF16
"ON"
)
endif
()
endif
()
endif
()
message
(
"DTYPES macro set to
${
DTYPES
}
"
)
message
(
"DTYPES macro set to
${
DTYPES
}
"
)
else
()
else
()
add_definitions
(
-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16
)
add_definitions
(
-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16
)
if
(
NOT DEFINED
${
CK_ENABLE_ALL_DTYPES
}
)
set
(
CK_ENABLE_ALL_DTYPES
"ON"
)
set
(
CK_ENABLE_ALL_DTYPES
"ON"
)
endif
()
endif
()
endif
()
if
(
GPU_TARGETS
)
if
(
GPU_TARGETS
)
...
@@ -73,7 +63,8 @@ message(STATUS "Build with HIP ${hip_VERSION}")
...
@@ -73,7 +63,8 @@ message(STATUS "Build with HIP ${hip_VERSION}")
# add all example subdir
# add all example subdir
file
(
GLOB dir_list LIST_DIRECTORIES true *
)
file
(
GLOB dir_list LIST_DIRECTORIES true *
)
FOREACH
(
subdir
${
dir_list
}
)
FOREACH
(
subdir
${
dir_list
}
)
IF
(
IS_DIRECTORY
"
${
subdir
}
"
AND
(
NOT
"
${
subdir
}
"
MATCHES
"build"
))
IF
(
IS_DIRECTORY
"
${
subdir
}
"
AND
(
NOT
"
${
subdir
}
"
MATCHES
"build"
)
AND
(
NOT
"
${
subdir
}
"
MATCHES
".vscode"
))
add_subdirectory
(
${
subdir
}
)
add_subdirectory
(
${
subdir
}
)
ENDIF
()
ENDIF
()
ENDFOREACH
()
ENDFOREACH
()
cmake/EnableCompilerWarnings.cmake
View file @
6bbb94f4
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#
#
# MIT License
# MIT License
#
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
# Copyright (c) 2017
-2024
Advanced Micro Devices, Inc.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -96,6 +96,7 @@ else()
...
@@ -96,6 +96,7 @@ else()
-Wno-covered-switch-default
-Wno-covered-switch-default
-Wno-unsafe-buffer-usage
-Wno-unsafe-buffer-usage
-Wno-unused-lambda-capture
-Wno-unused-lambda-capture
-Wno-nvcc-compat
)
)
else
()
else
()
if
(
CMAKE_
${
COMPILER
}
_COMPILER_ID MATCHES
"GNU"
AND
${
COMPILER
}
MATCHES
"CXX"
)
if
(
CMAKE_
${
COMPILER
}
_COMPILER_ID MATCHES
"GNU"
AND
${
COMPILER
}
MATCHES
"CXX"
)
...
...
codegen/CMakeLists.txt
View file @
6bbb94f4
cmake_minimum_required
(
VERSION 3.16
)
cmake_minimum_required
(
VERSION 3.16
)
project
(
composable_kernel_host
)
project
(
composable_kernel_host
LANGUAGES CXX HIP
)
set
(
CMAKE_EXPORT_COMPILE_COMMANDS ON
)
set
(
CMAKE_EXPORT_COMPILE_COMMANDS ON
)
...
@@ -12,24 +12,38 @@ find_package(ROCM)
...
@@ -12,24 +12,38 @@ find_package(ROCM)
include
(
ROCMInstallTargets
)
include
(
ROCMInstallTargets
)
include
(
ROCMTest
)
include
(
ROCMTest
)
add_compile_options
(
-std=c++17
)
find_package
(
hip
)
## HIP
set
(
CMAKE_HIP_PLATFORM amd
)
set
(
CMAKE_HIP_COMPILER
${
CMAKE_CXX_COMPILER
}
)
set
(
CMAKE_HIP_EXTENSIONS ON
)
message
(
"CMAKE_HIP_COMPILER:
${
CMAKE_HIP_COMPILER
}
"
)
# add include directories
include_directories
(
BEFORE
${
PROJECT_BINARY_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/library/include
${
HIP_INCLUDE_DIRS
}
)
list
(
APPEND CMAKE_MODULE_PATH
${
CK_ROOT
}
/cmake
)
list
(
APPEND CMAKE_MODULE_PATH
${
CK_ROOT
}
/cmake
)
include
(
Embed
)
include
(
Embed
)
file
(
GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
file
(
GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${
CK_ROOT
}
/include/ck/*.hpp
)
${
CK_ROOT
}
/include/ck/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
message
(
STATUS
"RELATIVE:
${
CK_ROOT
}
/include"
)
message
(
STATUS
"RELATIVE:
${
CK_ROOT
}
/include"
)
add_embed_library
(
ck_headers
${
KERNEL_FILES
}
RELATIVE
${
CK_ROOT
}
/include
)
add_embed_library
(
ck_headers
${
KERNEL_FILES
}
RELATIVE
${
CK_ROOT
}
/include
)
add_definitions
(
-std=c++17
)
file
(
GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp
)
file
(
GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp
)
# TODO: Use object library
# TODO: Use object library
add_library
(
ck_host STATIC
${
SOURCES
}
)
add_library
(
ck_host STATIC
${
SOURCES
}
)
target_link_libraries
(
ck_host PRIVATE ck_headers
)
target_link_libraries
(
ck_host PRIVATE ck_headers
)
set_target_properties
(
ck_host PROPERTIES
set_target_properties
(
ck_host PROPERTIES
LINKER_LANGUAGE CXX
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE ON
)
POSITION_INDEPENDENT_CODE ON
)
target_include_directories
(
ck_host PUBLIC
target_include_directories
(
ck_host PUBLIC
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
...
...
codegen/driver/main.cpp
View file @
6bbb94f4
...
@@ -5,24 +5,27 @@
...
@@ -5,24 +5,27 @@
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/stringutils.hpp"
using
ck
::
host
::
Transform
;
using
ck
::
host
::
Transform
;
struct
Emitters
struct
Emitters
{
{
// retrieve the hard-coded instances provided, template them, and then store them in a map
std
::
unordered_map
<
std
::
string
,
std
::
function
<
std
::
vector
<
std
::
string
>
()
>>
m
;
std
::
unordered_map
<
std
::
string
,
std
::
function
<
std
::
vector
<
std
::
string
>
()
>>
m
;
template
<
class
T
>
template
<
class
T
>
void
Register
(
const
std
::
string
&
name
)
void
Register
(
const
std
::
string
&
name
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
{
{
m
[
name
]
=
[]
{
m
[
name
]
=
[
&
]
{
auto
configs
=
T
::
CreateOperations
();
auto
configs
=
T
::
CreateOperations
(
prologue
,
epilogue
);
return
Transform
(
configs
,
[](
const
auto
&
ops
)
{
return
ToTuple
(
ops
);
});
return
Transform
(
configs
,
[](
const
auto
&
ops
)
{
return
ToTuple
(
ops
);
});
};
};
}
}
// takes in an operation instance and uses it to substitute the correct values into the template
template
<
class
T
>
template
<
class
T
>
static
std
::
string
ToTuple
(
const
T
&
ops
)
static
std
::
string
ToTuple
(
const
T
&
ops
)
{
{
...
@@ -31,6 +34,7 @@ struct Emitters
...
@@ -31,6 +34,7 @@ struct Emitters
return
"std::tuple<
\n
"
+
ck
::
host
::
JoinStrings
(
templates
,
",
\n
"
)
+
">"
;
return
"std::tuple<
\n
"
+
ck
::
host
::
JoinStrings
(
templates
,
",
\n
"
)
+
">"
;
}
}
// Join together all the strings in the map
std
::
string
Emit
(
const
std
::
string
&
name
)
{
return
ck
::
host
::
JoinStrings
(
m
.
at
(
name
)(),
"
\n
"
);
}
std
::
string
Emit
(
const
std
::
string
&
name
)
{
return
ck
::
host
::
JoinStrings
(
m
.
at
(
name
)(),
"
\n
"
);
}
std
::
vector
<
std
::
string
>
List
()
const
std
::
vector
<
std
::
string
>
List
()
const
...
@@ -43,9 +47,38 @@ int main(int argc, const char* argv[])
...
@@ -43,9 +47,38 @@ int main(int argc, const char* argv[])
{
{
std
::
string
prog
=
argv
[
0
];
std
::
string
prog
=
argv
[
0
];
std
::
vector
<
std
::
string
>
args
(
argv
+
1
,
argv
+
argc
);
std
::
vector
<
std
::
string
>
args
(
argv
+
1
,
argv
+
argc
);
// Specify problem type and problem size
ck
::
host
::
device_gemm_multiple_d
::
Problem
prob
;
prob
.
M
=
1024
;
prob
.
N
=
1024
;
prob
.
K
=
1024
;
// user provided fusion
std
::
string
prologue
=
""
;
std
::
string
epilogue
=
R"(
struct Epilogue
{
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename D>
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
}
float alpha_;
float beta_;
};)"
;
// Load in operations into the Register
Emitters
e
;
Emitters
e
;
e
.
Register
<
ck
::
host
::
device_gemm_multiple_d
::
Operation_Xdl_CShuffle
>
(
e
.
Register
<
ck
::
host
::
device_gemm_multiple_d
::
Operation_Xdl_CShuffle
>
(
"DeviceGemmMultipleD_Xdl_CShuffle"
);
"DeviceGemmMultipleD_Xdl_CShuffle"
,
prologue
,
epilogue
);
if
(
args
.
empty
()
or
std
::
any_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
arg
)
{
if
(
args
.
empty
()
or
std
::
any_of
(
args
.
begin
(),
args
.
end
(),
[](
auto
arg
)
{
return
arg
==
"-h"
or
arg
==
"--help"
;
return
arg
==
"-h"
or
arg
==
"--help"
;
...
@@ -64,6 +97,7 @@ int main(int argc, const char* argv[])
...
@@ -64,6 +97,7 @@ int main(int argc, const char* argv[])
return
0
;
return
0
;
}
}
// print out all the instances for the operation that was chosen at the command line
for
(
auto
name
:
args
)
for
(
auto
name
:
args
)
std
::
cout
<<
e
.
Emit
(
name
)
<<
std
::
endl
;
std
::
cout
<<
e
.
Emit
(
name
)
<<
std
::
endl
;
...
...
codegen/include/ck/host/device_gemm_multiple_d.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
View file @
6bbb94f4
...
@@ -14,10 +14,15 @@ namespace ck {
...
@@ -14,10 +14,15 @@ namespace ck {
namespace
host
{
namespace
host
{
namespace
device_gemm_multiple_d
{
namespace
device_gemm_multiple_d
{
// defines all values need for an instance of fwd conv
struct
Operation_Xdl_CShuffle
struct
Operation_Xdl_CShuffle
{
{
static
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
CreateOperations
();
// returns a vector of instances, only given fusion operators: will use default problem spec
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
);
static
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
// returns a vector of instances, given a problem spec and fusion operators
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
TensorDesc
A
{};
TensorDesc
A
{};
TensorDesc
B
{};
TensorDesc
B
{};
DataType
acc
=
DataType
::
Float
;
DataType
acc
=
DataType
::
Float
;
...
@@ -27,13 +32,21 @@ struct Operation_Xdl_CShuffle
...
@@ -27,13 +32,21 @@ struct Operation_Xdl_CShuffle
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b_elem_op
=
PassThrough
;
std
::
string
b_elem_op
=
PassThrough
;
std
::
string
cde_elem_op
=
Bilinear
;
std
::
string
cde_elem_op
=
Bilinear
;
std
::
string
prologue
=
""
;
std
::
string
epilogue
=
""
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default"
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default"
;
// tuning parameters
operation
::
TileDesc
tile_desc
{};
operation
::
TileDesc
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
);
// returns a templated instance
Solution
ToSolution
()
const
;
Solution
ToSolution
()
const
;
};
};
...
...
codegen/include/ck/host/device_gemm_multiple_d/problem.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,11 +12,14 @@ namespace ck {
...
@@ -12,11 +12,14 @@ namespace ck {
namespace
host
{
namespace
host
{
namespace
device_gemm_multiple_d
{
namespace
device_gemm_multiple_d
{
// defines the problem specification for a GEMM operation
struct
Problem
struct
Problem
{
{
std
::
size_t
M
=
0
;
// dimensions for GEMM operation
std
::
size_t
N
=
0
;
std
::
size_t
M
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
// layouts for tensors
bool
TransA
=
false
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB
=
false
;
bool
TransE
=
false
;
bool
TransE
=
false
;
...
@@ -29,9 +32,13 @@ struct Problem
...
@@ -29,9 +32,13 @@ struct Problem
std
::
string
BElementOp
=
PassThrough
;
std
::
string
BElementOp
=
PassThrough
;
std
::
string
CDEElementOp
=
PassThrough
;
std
::
string
CDEElementOp
=
PassThrough
;
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
// returns a list of instances based on the problem spec and provided fusion operations
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
;
};
};
}
// namespace device_gemm_multiple_d
}
// namespace device_gemm_multiple_d
...
...
codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp
0 → 100644
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
namespace
ck
{
namespace
host
{
namespace
conv
{
// defines the values needed for an instance of forward convolution and functions to return
// (templated) instances
struct
Operation_Conv_Fwd_Xdl_Cshuffle
{
// returns a vector of instances given the fusion operations, uses default values for problem
// spec
static
std
::
vector
<
Operation_Conv_Fwd_Xdl_Cshuffle
>
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
// returns a vector of instances, provided with a problem spec and fusion operations
static
std
::
vector
<
Operation_Conv_Fwd_Xdl_Cshuffle
>
CreateOperations
(
const
Problem_Conv_Fwd
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
std
::
size_t
NumDim
;
TensorDesc
A
{};
TensorDesc
B
{};
DataType
acc
=
DataType
::
Float
;
DataType
cs_type
=
DataType
::
Half
;
std
::
vector
<
TensorDesc
>
Ds
=
{};
TensorDesc
E
{};
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b_elem_op
=
PassThrough
;
std
::
string
cde_elem_op
=
PassThrough
;
std
::
string
prologue
=
""
;
std
::
string
epilogue
=
""
;
std
::
string
conv_specialization
=
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Default"
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::MNKPadding"
;
// tuning parameters
operation
::
TileDesc
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
// functions to update fusion operations if they are provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
// returns a templated instance
Solution
ToSolution
()
const
;
};
}
// namespace conv
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp
0 → 100644
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/types.hpp"
namespace
ck
{
namespace
host
{
namespace
conv
{
// defines the problem specification for a forward convolution operation
struct
Problem_Conv_Fwd
{
std
::
size_t
NumDim
=
0
;
// size of a forward convolution operation
std
::
size_t
G
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
C
=
0
;
std
::
size_t
Hi
=
0
;
std
::
size_t
Wi
=
0
;
std
::
size_t
Ho
=
0
;
std
::
size_t
Wo
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
Y
=
0
;
std
::
size_t
X
=
0
;
Layout
ALayout
=
Layout
::
NHWGC
;
Layout
BLayout
=
Layout
::
GKYXC
;
Layout
ELayout
=
Layout
::
NHWGK
;
std
::
vector
<
Layout
>
DsLayout
=
{};
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
EDataType
=
DataType
::
Half
;
std
::
vector
<
DataType
>
DsDataType
=
{};
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CDEElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
// returns a list of instances based on the problem spec and provided fusion operations
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
;
};
}
// namespace conv
}
// namespace host
}
// namespace ck
codegen/include/ck/host/headers.hpp
View file @
6bbb94f4
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
#pragma once
#pragma once
#include <string>
#include <string>
#include <string_view>
#include <utility>
#include <utility>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
...
codegen/include/ck/host/operation/gemm.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
codegen/include/ck/host/stringutils.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
codegen/include/ck/host/types.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
namespace
ck
{
namespace
ck
{
namespace
host
{
namespace
host
{
// holds the templated instance, substitues values into template from instancess
struct
Solution
struct
Solution
{
{
...
@@ -33,6 +34,7 @@ struct Solution
...
@@ -33,6 +34,7 @@ struct Solution
std
::
unordered_map
<
std
::
string
,
std
::
string
>
template_values
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
template_values
;
};
};
// supported data types
enum
class
DataType
enum
class
DataType
{
{
Half
,
Half
,
...
@@ -40,22 +42,28 @@ enum class DataType
...
@@ -40,22 +42,28 @@ enum class DataType
Int8
,
Int8
,
Int32
Int32
};
};
std
::
string
ToString
(
DataType
dt
);
std
::
string
ToString
(
DataType
dt
);
// supported layouts: gemm and fwd conv
enum
class
Layout
enum
class
Layout
{
{
Row
,
Row
,
Column
Column
,
GKYXC
,
GKCYX
,
GNHWK
,
GNHWC
,
NHWGC
,
NHWGK
};
};
std
::
string
ToString
(
Layout
dl
);
std
::
string
ToString
(
Layout
dl
);
Layout
ToLayout
(
bool
Trans
);
// returns the layout for gemm
// supported GEMM types
enum
class
GemmType
enum
class
GemmType
{
{
Default
Default
};
};
std
::
string
ToString
(
GemmType
gt
);
std
::
string
ToString
(
GemmType
gt
);
struct
TensorDesc
struct
TensorDesc
...
...
codegen/include/ck/host/utils.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <cstdint>
#include <cstdint>
#include <unordered_set>
#include <unordered_set>
#include <numeric>
#include <iterator>
namespace
ck
{
namespace
ck
{
namespace
host
{
namespace
host
{
...
@@ -12,6 +14,5 @@ namespace host {
...
@@ -12,6 +14,5 @@ namespace host {
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
();
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
();
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
codegen/src/device_gemm_multiple_d.cpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
...
@@ -11,23 +11,28 @@ namespace ck {
...
@@ -11,23 +11,28 @@ namespace ck {
namespace
host
{
namespace
host
{
namespace
device_gemm_multiple_d
{
namespace
device_gemm_multiple_d
{
// return the relevant device op file based on the operation
std
::
string
Problem
::
GetIncludeHeader
()
const
std
::
string
Problem
::
GetIncludeHeader
()
const
{
{
return
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
;
return
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
;
}
}
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
)
const
// returns templated instances when provided with a problem specification
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
{
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
return
{};
auto
ops
=
ck
::
host
::
device_gemm_multiple_d
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
);
auto
ops
=
ck
::
host
::
device_gemm_multiple_d
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
std
::
vector
<
Solution
>
result
;
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
return
op
.
ToSolution
();
// template instance with correct values
});
});
return
result
;
return
result
;
}
}
}
// namespace device_gemm_multiple_d
}
// namespace device_gemm_multiple_d
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
\ No newline at end of file
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
View file @
6bbb94f4
...
@@ -10,6 +10,7 @@ namespace ck {
...
@@ -10,6 +10,7 @@ namespace ck {
namespace
host
{
namespace
host
{
namespace
device_gemm_multiple_d
{
namespace
device_gemm_multiple_d
{
// calculate appropriate Gemm Specification based on input tensor dimensions
static
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
static
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
n
,
const
std
::
size_t
k
,
const
std
::
size_t
k
,
...
@@ -30,9 +31,40 @@ static std::string GetGemmSpec(const std::size_t m,
...
@@ -30,9 +31,40 @@ static std::string GetGemmSpec(const std::size_t m,
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
}
}
// function to update prologue/epilogue with user provided operation
void
Operation_Xdl_CShuffle
::
update_prologue
(
const
std
::
string
&
prologue
)
{
if
(
!
prologue
.
empty
())
{
this
->
prologue
=
prologue
;
this
->
cde_elem_op
=
"CDEElementOp"
;
}
else
{
this
->
prologue
=
""
;
}
}
void
Operation_Xdl_CShuffle
::
update_epilogue
(
const
std
::
string
&
epilogue
)
{
if
(
!
epilogue
.
empty
())
{
this
->
epilogue
=
epilogue
;
this
->
cde_elem_op
=
"CDEElementOp"
;
}
else
{
this
->
epilogue
=
""
;
}
}
// accounts for all possible combinations of Row/Col major
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
const
Problem
&
prob
)
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
{
{
std
::
vector
<
Operation_Xdl_CShuffle
>
result
;
std
::
vector
<
Operation_Xdl_CShuffle
>
result
;
...
@@ -155,6 +187,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
...
@@ -155,6 +187,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
// clang-format on
// clang-format on
};
};
// choose correct arrangement of tuning parameters based on the layout of each tensor
const
auto
a_block_descriptions
=
const
auto
a_block_descriptions
=
prob
.
TransA
?
a_block_descriptions_colmajor
:
a_block_descriptions_rowmajor
;
prob
.
TransA
?
a_block_descriptions_colmajor
:
a_block_descriptions_rowmajor
;
const
auto
b_block_descriptions
=
const
auto
b_block_descriptions
=
...
@@ -165,6 +198,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
...
@@ -165,6 +198,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
// Put all values together into a single operation > store into the result vector
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
{
{
Operation_Xdl_CShuffle
x
;
Operation_Xdl_CShuffle
x
;
...
@@ -188,12 +222,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
...
@@ -188,12 +222,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
x
.
tile_desc
.
m_per_block
,
x
.
tile_desc
.
m_per_block
,
x
.
tile_desc
.
n_per_block
,
x
.
tile_desc
.
n_per_block
,
x
.
tile_desc
.
k_per_block
);
x
.
tile_desc
.
k_per_block
);
x
.
update_prologue
(
prologue
);
x
.
update_epilogue
(
epilogue
);
result
.
push_back
(
x
);
result
.
push_back
(
x
);
}
}
return
result
;
return
result
;
}
}
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
Operation_Xdl_CShuffle
::
CreateOperations
()
// set up instances when not provided with a problem specification, use default operation values and
// all possible layout combinations
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
Operation_Xdl_CShuffle
::
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
{
{
std
::
vector
<
Problem
>
problems
;
std
::
vector
<
Problem
>
problems
;
for
(
bool
TransA
:
{
true
,
false
})
for
(
bool
TransA
:
{
true
,
false
})
...
@@ -204,7 +243,8 @@ std::vector<std::vector<Operation_Xdl_CShuffle>> Operation_Xdl_CShuffle::CreateO
...
@@ -204,7 +243,8 @@ std::vector<std::vector<Operation_Xdl_CShuffle>> Operation_Xdl_CShuffle::CreateO
prob
.
TransB
=
TransB
;
prob
.
TransB
=
TransB
;
problems
.
push_back
(
prob
);
problems
.
push_back
(
prob
);
}
}
return
Transform
(
problems
,
[](
const
Problem
&
p
)
{
return
CreateOperations
(
p
);
});
return
Transform
(
problems
,
[
&
](
const
Problem
&
p
)
{
return
CreateOperations
(
p
,
prologue
,
epilogue
);
});
}
}
static
const
char
*
const
DeviceGemmMultipleD_Xdl_CShuffleTemplate
=
static
const
char
*
const
DeviceGemmMultipleD_Xdl_CShuffleTemplate
=
...
@@ -224,9 +264,20 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
...
@@ -224,9 +264,20 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferScalarPerVector_NPerBlock}>"
;
"${CDEBlockTransferScalarPerVector_NPerBlock}>"
;
// use hardcoded instances from vector of operations to substitute values into instance template
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
{
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
{
"name"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
ak1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
bk1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_Xdl_per_wave
)},
{
"LayoutA"
,
ToString
(
this
->
A
.
layout
)},
{
"LayoutA"
,
ToString
(
this
->
A
.
layout
)},
{
"LayoutB"
,
ToString
(
this
->
B
.
layout
)},
{
"LayoutB"
,
ToString
(
this
->
B
.
layout
)},
{
"LayoutDs"
,
{
"LayoutDs"
,
...
...
codegen/src/device_grouped_conv_fwd_multiple_abd.cpp
0 → 100644
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
#include <iostream>
namespace
ck
{
namespace
host
{
namespace
conv
{
// return the relevant device op file based on the operation
// NOTE: this is a modified version of the original CK file that calls the kernel from a device
// function and makes the Argument class accessible on the device
std
::
string
Problem_Conv_Fwd
::
GetIncludeHeader
()
const
{
return
"ck/tensor_operation/gpu/device/impl/"
"codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
;
}
// return vector of forward convolution instances when provided with a problem instance
std
::
vector
<
Solution
>
Problem_Conv_Fwd
::
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
auto
ops
=
ck
::
host
::
conv
::
Operation_Conv_Fwd_Xdl_Cshuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
});
return
result
;
}
}
// namespace conv
}
// namespace host
}
// namespace ck
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment