Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2a30cfdd
Unverified
Commit
2a30cfdd
authored
Feb 12, 2025
by
arai713
Committed by
GitHub
Feb 12, 2025
Browse files
Merge branch 'develop' into codegen-enable-hiprtc
parents
9533a172
78195ccc
Changes
740
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
714 additions
and
36 deletions
+714
-36
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+8
-0
client_example/README.md
client_example/README.md
+2
-0
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-1
codegen/README.md
codegen/README.md
+2
-0
codegen/driver/main.cpp
codegen/driver/main.cpp
+2
-0
codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
...de/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
+61
-0
codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
...lude/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
+47
-0
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
+2
-0
codegen/include/ck/host/operation/gemm.hpp
codegen/include/ck/host/operation/gemm.hpp
+20
-0
codegen/include/ck/host/types.hpp
codegen/include/ck/host/types.hpp
+15
-0
codegen/src/device_batched_gemm_softmax_gemm.cpp
codegen/src/device_batched_gemm_softmax_gemm.cpp
+38
-0
codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
...vice_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
+408
-0
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
...gen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
+67
-35
codegen/src/headers.cpp
codegen/src/headers.cpp
+3
-0
codegen/src/types.cpp
codegen/src/types.cpp
+23
-0
codegen/test/gemm_multiple_d.cpp
codegen/test/gemm_multiple_d.cpp
+3
-0
codegen/test/grouped_conv_fwd_multiple_d_v1.cpp
codegen/test/grouped_conv_fwd_multiple_d_v1.cpp
+3
-0
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
+3
-0
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
+3
-0
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
+3
-0
No files found.
client_example/CMakeLists.txt
View file @
2a30cfdd
...
@@ -56,6 +56,14 @@ if (GPU_TARGETS)
...
@@ -56,6 +56,14 @@ if (GPU_TARGETS)
add_definitions
(
-DCK_USE_WMMA
)
add_definitions
(
-DCK_USE_WMMA
)
set
(
CK_USE_WMMA
"ON"
)
set
(
CK_USE_WMMA
"ON"
)
endif
()
endif
()
if
(
GPU_TARGETS MATCHES
"gfx12"
OR GPU_TARGETS MATCHES
"gfx950"
)
add_definitions
(
-DCK_USE_OCP_FP8
)
set
(
CK_USE_OCP_FP8
"ON"
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx90a"
OR GPU_TARGETS MATCHES
"gfx94"
)
add_definitions
(
-DCK_USE_FNUZ_FP8
)
set
(
CK_USE_FNUZ_FP8
"ON"
)
endif
()
else
()
else
()
add_definitions
(
-DCK_USE_WMMA -DCK_USE_XDL
)
add_definitions
(
-DCK_USE_WMMA -DCK_USE_XDL
)
set
(
CK_USE_XDL
"ON"
)
set
(
CK_USE_XDL
"ON"
)
...
...
client_example/README.md
View file @
2a30cfdd
[
Back to the main page
](
../README.md
)
# Composable Kernel client examples
##
##
Client application links to CK library, and therefore CK library needs to be installed before building client applications.
Client application links to CK library, and therefore CK library needs to be installed before building client applications.
...
...
cmake/EnableCompilerWarnings.cmake
View file @
2a30cfdd
...
@@ -66,7 +66,7 @@ else()
...
@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunreachable-code
-Wunused
-Wunused
-Wno-reserved-identifier
-Wno-reserved-identifier
-Werror
-Werror
-Wno-option-ignored
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-extra-semi-stmt
...
...
codegen/README.md
0 → 100644
View file @
2a30cfdd
[
Back to the main page
](
../README.md
)
# Composable Kernel codegen
\ No newline at end of file
codegen/driver/main.cpp
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <functional>
#include <iostream>
#include <iostream>
...
...
codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
0 → 100644
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// defines all values need for an instance of fwd conv
struct
Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static
std
::
vector
<
std
::
vector
<
Operation_Xdl_CShuffle
>>
CreateOperations
(
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
// returns a vector of instances, given a problem spec and fusion operators
static
std
::
vector
<
Operation_Xdl_CShuffle
>
CreateOperations
(
const
Problem
&
prob
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
);
TensorDesc
A
{};
TensorDesc
B
{};
TensorDesc
B1
{};
TensorDesc
C
{};
DataType
acc
=
DataType
::
Float
;
DataType
cs_type
=
DataType
::
Half
;
std
::
string
a_elem_op
=
PassThrough
;
std
::
string
b_elem_op
=
PassThrough
;
std
::
string
b1_elem_op
=
PassThrough
;
std
::
string
c_elem_op
=
PassThrough
;
std
::
string
acc_elem_op
=
Scale
;
std
::
string
prologue
=
""
;
std
::
string
epilogue
=
""
;
std
::
string
gemm_specialization
=
"ck::tensor_operation::device::GemmSpecialization::Default"
;
// tuning parameters
operation
::
TileDescGemmGemm
tile_desc
{};
operation
::
BlockTransferDesc
a_block_transfer
{};
operation
::
BlockTransferDesc
b0_block_transfer
{};
operation
::
BlockTransferDesc
b1_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
bool
mask_out_upper_triangle
=
false
;
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_epilogue
(
const
std
::
string
&
epilogue
);
/**constexpr**/
bool
IsSupported
(
std
::
size_t
MRaw_
,
std
::
size_t
NRaw_
,
std
::
size_t
KRaw_
,
std
::
size_t
Gemm1NRaw_
);
// returns a templated instance
Solution
ToSolution
()
const
;
};
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
0 → 100644
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// defines the problem specification for a GEMM operation
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
PassThrough
;
std
::
string
BElementOp
=
PassThrough
;
std
::
string
B1ElementOp
=
PassThrough
;
std
::
string
CElementOp
=
PassThrough
;
std
::
string
AccElementOp
=
Scale
;
// returns the correct device op file for the operation
std
::
string
GetIncludeHeader
()
const
;
// returns a list of instances based on the problem spec and provided fusion operations
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
;
};
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
View file @
2a30cfdd
...
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
...
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
BlockTransferDesc
b_block_transfer
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CShuffleDesc
cshuffle
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
operation
::
CBlockTransferDesc
c_block_transfer
{};
LoopScheduler
loop_scheduler
{};
PipelineVersion
pipeline_version
{};
// functions to update fusion operators if provided
// functions to update fusion operators if provided
void
update_prologue
(
const
std
::
string
&
prologue
);
void
update_prologue
(
const
std
::
string
&
prologue
);
...
...
codegen/include/ck/host/operation/gemm.hpp
View file @
2a30cfdd
...
@@ -23,6 +23,26 @@ struct TileDesc
...
@@ -23,6 +23,26 @@ struct TileDesc
int
n_Xdl_per_wave
=
0
;
int
n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
};
struct
TileDescGemmGemm
{
int
block_size
=
0
;
int
gemm01_m_per_block
=
0
;
int
gemm0_n_per_block
=
0
;
int
gemm0_k_per_block
=
0
;
int
gemm1_n_per_block
=
0
;
int
gemm1_k_per_block
=
0
;
int
ak1
=
0
;
int
bk1
=
0
;
int
b1k1
=
0
;
int
m_per_XDL
=
0
;
int
n_per_XDL
=
0
;
int
gemm0_m_Xdl_per_wave
=
0
;
int
gemm0_n_Xdl_per_wave
=
0
;
int
gemm1_n_Xdl_per_wave
=
0
;
int
num_gemmk_prefetch_stage
=
0
;
};
struct
BlockTransferDesc
struct
BlockTransferDesc
{
{
std
::
string
thread_cluster_length
=
""
;
std
::
string
thread_cluster_length
=
""
;
...
...
codegen/include/ck/host/types.hpp
View file @
2a30cfdd
...
@@ -66,6 +66,20 @@ enum class GemmType
...
@@ -66,6 +66,20 @@ enum class GemmType
};
};
std
::
string
ToString
(
GemmType
gt
);
std
::
string
ToString
(
GemmType
gt
);
enum
class
LoopScheduler
{
Default
,
Interwave
,
};
std
::
string
ToString
(
LoopScheduler
ls
);
enum
class
PipelineVersion
{
v1
,
v2
};
std
::
string
ToString
(
PipelineVersion
pv
);
struct
TensorDesc
struct
TensorDesc
{
{
DataType
element
;
DataType
element
;
...
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
...
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
constexpr
const
char
*
PassThrough
=
"ck::tensor_operation::element_wise::PassThrough"
;
constexpr
const
char
*
PassThrough
=
"ck::tensor_operation::element_wise::PassThrough"
;
constexpr
const
char
*
Bilinear
=
"ck::tensor_operation::element_wise::Bilinear"
;
constexpr
const
char
*
Bilinear
=
"ck::tensor_operation::element_wise::Bilinear"
;
constexpr
const
char
*
Scale
=
"ck::tensor_operation::element_wise::Scale"
;
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
codegen/src/device_batched_gemm_softmax_gemm.cpp
0 → 100644
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
// return the relevant device op file based on the operation
std
::
string
Problem
::
GetIncludeHeader
()
const
{
return
"ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
;
}
// returns templated instances when provided with a problem specification
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
,
const
std
::
string
&
prologue
,
const
std
::
string
&
epilogue
)
const
{
if
(
get_xdlop_archs
().
count
(
arch
)
==
0
)
return
{};
auto
ops
=
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Operation_Xdl_CShuffle
::
CreateOperations
(
*
this
,
prologue
,
epilogue
);
// obtains vector of instances
std
::
vector
<
Solution
>
result
;
std
::
transform
(
ops
.
begin
(),
ops
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
op
)
{
return
op
.
ToSolution
();
// template instance with correct values
});
return
result
;
}
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
0 → 100644
View file @
2a30cfdd
This diff is collapsed.
Click to expand it.
codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
View file @
2a30cfdd
...
@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
...
@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
// accounts for all possible combinations of Row/Col major
// accounts for all possible combinations of Row/Col major
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
static
Layout
ToLayout
(
bool
Trans
)
{
return
Trans
?
Layout
::
Column
:
Layout
::
Row
;
}
// clang-format off
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
// clang-format on
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
// instances
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
std
::
vector
<
Operation_Xdl_CShuffle
>
Operation_Xdl_CShuffle
::
CreateOperations
(
...
@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
1
},
{
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
1
},
{
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
1
},
{
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
1
},
{
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
1
},
{
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
1
},
// Irregular tile
{
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
// | | | | | | |
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// clang-format on
// clang-format on
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
};
};
std
::
vector
<
operation
::
BlockTransferDesc
>
b_block_descriptions_rowmajor
=
{
std
::
vector
<
operation
::
BlockTransferDesc
>
b_block_descriptions_rowmajor
=
{
...
@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
{
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
},
// Irregular tile
{
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{
S
<
1
,
16
,
1
,
8
>
,
8
},
{
S
<
1
,
16
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
{
S
<
1
,
32
,
1
,
8
>
,
8
},
// Irregular tile
{
S
<
1
,
16
,
1
,
4
>
,
1
},
// clang-format on
// clang-format on
};
};
...
@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
cshuffle_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
assert
(
tile_descriptions
.
size
()
==
c_block_descriptions
.
size
());
// Put all values together into a single operation > store into the result vector
const
std
::
vector
<
std
::
tuple
<
LoopScheduler
,
PipelineVersion
>>
scheduler_pipeline_descriptions
=
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
{
{
LoopScheduler
::
Default
,
PipelineVersion
::
v1
},
{
LoopScheduler
::
Interwave
,
PipelineVersion
::
v1
},
{
LoopScheduler
::
Default
,
PipelineVersion
::
v2
},
};
for
(
auto
[
loop_scheduler
,
pipeline_version
]
:
scheduler_pipeline_descriptions
)
{
{
Operation_Xdl_CShuffle
x
;
// Put all values together into a single operation > store into the result vector
x
.
tile_desc
=
tile_descriptions
[
i
];
for
(
std
::
size_t
i
=
0
;
i
<
tile_descriptions
.
size
();
i
++
)
x
.
a_block_transfer
=
a_block_descriptions
[
i
];
{
x
.
b_block_transfer
=
b_block_descriptions
[
i
];
Operation_Xdl_CShuffle
x
;
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
tile_desc
=
tile_descriptions
[
i
];
x
.
c_block_transfer
=
c_block_descriptions
[
i
];
x
.
a_block_transfer
=
a_block_descriptions
[
i
];
x
.
A
=
TensorDesc
{
prob
.
ADataType
,
ToLayout
(
prob
.
TransA
)};
x
.
b_block_transfer
=
b_block_descriptions
[
i
];
x
.
B
=
TensorDesc
{
prob
.
BDataType
,
ToLayout
(
prob
.
TransB
)};
x
.
cshuffle
=
cshuffle_descriptions
[
i
];
x
.
E
=
TensorDesc
{
prob
.
EDataType
,
ToLayout
(
prob
.
TransE
)};
x
.
c_block_transfer
=
c_block_descriptions
[
i
];
x
.
Ds
=
Transform
(
prob
.
DsTrans
,
prob
.
DsDataType
,
[](
auto
trans
,
auto
dt
)
{
x
.
A
=
TensorDesc
{
prob
.
ADataType
,
ToLayout
(
prob
.
TransA
)};
return
TensorDesc
{
dt
,
ToLayout
(
trans
)};
x
.
B
=
TensorDesc
{
prob
.
BDataType
,
ToLayout
(
prob
.
TransB
)};
});
x
.
E
=
TensorDesc
{
prob
.
EDataType
,
ToLayout
(
prob
.
TransE
)};
x
.
a_elem_op
=
prob
.
AElementOp
;
x
.
Ds
=
Transform
(
prob
.
DsTrans
,
prob
.
DsDataType
,
[](
auto
trans
,
auto
dt
)
{
x
.
b_elem_op
=
prob
.
BElementOp
;
return
TensorDesc
{
dt
,
ToLayout
(
trans
)};
x
.
cde_elem_op
=
prob
.
CDEElementOp
;
});
x
.
gemm_specialization
=
GetGemmSpec
(
prob
.
M
,
x
.
a_elem_op
=
prob
.
AElementOp
;
prob
.
N
,
x
.
b_elem_op
=
prob
.
BElementOp
;
prob
.
K
,
x
.
cde_elem_op
=
prob
.
CDEElementOp
;
x
.
tile_desc
.
m_per_block
,
x
.
gemm_specialization
=
GetGemmSpec
(
prob
.
M
,
x
.
tile_desc
.
n_per_block
,
prob
.
N
,
x
.
tile_desc
.
k_per_block
);
prob
.
K
,
x
.
update_prologue
(
prologue
);
x
.
tile_desc
.
m_per_block
,
x
.
update_epilogue
(
epilogue
);
x
.
tile_desc
.
n_per_block
,
result
.
push_back
(
x
);
x
.
tile_desc
.
k_per_block
);
x
.
loop_scheduler
=
loop_scheduler
;
x
.
pipeline_version
=
pipeline_version
;
x
.
update_prologue
(
prologue
);
x
.
update_epilogue
(
epilogue
);
result
.
push_back
(
x
);
}
}
}
return
result
;
return
result
;
}
}
...
@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
...
@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferScalarPerVector_NPerBlock}>"
;
"${CDEBlockTransferScalarPerVector_NPerBlock}
, ${LoopScheduler}, ${PipelineVersion}
>"
;
// use hardcoded instances from vector of operations to substitute values into instance template
// use hardcoded instances from vector of operations to substitute values into instance template
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
Solution
Operation_Xdl_CShuffle
::
ToSolution
()
const
...
@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
this
->
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
},
this
->
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
},
{
"CDEBlockTransferScalarPerVector_NPerBlock"
,
{
"CDEBlockTransferScalarPerVector_NPerBlock"
,
std
::
to_string
(
this
->
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
)},
std
::
to_string
(
this
->
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
)},
{
"LoopScheduler"
,
ToString
(
this
->
loop_scheduler
)},
{
"PipelineVersion"
,
ToString
(
this
->
pipeline_version
)},
};
};
return
Solution
{
InterpolateString
(
DeviceGemmMultipleD_Xdl_CShuffleTemplate
,
values
),
return
Solution
{
InterpolateString
(
DeviceGemmMultipleD_Xdl_CShuffleTemplate
,
values
),
...
...
codegen/src/headers.cpp
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/headers.hpp"
#include "ck/host/headers.hpp"
#include "ck_headers.hpp"
#include "ck_headers.hpp"
...
...
codegen/src/types.cpp
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/types.hpp"
#include "ck/host/types.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/stringutils.hpp"
#include <algorithm>
#include <algorithm>
...
@@ -56,6 +59,26 @@ std::string ToString(GemmType gt)
...
@@ -56,6 +59,26 @@ std::string ToString(GemmType gt)
throw
std
::
runtime_error
(
"Incorrect gemm type"
);
throw
std
::
runtime_error
(
"Incorrect gemm type"
);
}
}
std
::
string
ToString
(
LoopScheduler
ls
)
{
switch
(
ls
)
{
case
LoopScheduler
::
Default
:
return
"ck::LoopScheduler::Default"
;
case
LoopScheduler
::
Interwave
:
return
"ck::LoopScheduler::Interwave"
;
}
throw
std
::
runtime_error
(
"Incorrect LoopScheduler type"
);
}
std
::
string
ToString
(
PipelineVersion
pv
)
{
switch
(
pv
)
{
case
PipelineVersion
::
v1
:
return
"ck::PipelineVersion::v1"
;
case
PipelineVersion
::
v2
:
return
"ck::PipelineVersion::v2"
;
}
throw
std
::
runtime_error
(
"Incorrect PipelineVersion type"
);
}
std
::
string
SequenceStr
(
const
std
::
vector
<
int
>&
v
)
std
::
string
SequenceStr
(
const
std
::
vector
<
int
>&
v
)
{
{
return
"ck::Sequence<"
+
return
"ck::Sequence<"
+
...
...
codegen/test/gemm_multiple_d.cpp
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/headers.hpp"
...
...
codegen/test/grouped_conv_fwd_multiple_d_v1.cpp
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/headers.hpp"
...
...
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/headers.hpp"
...
...
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/headers.hpp"
...
...
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/headers.hpp"
...
...
Prev
1
2
3
4
5
6
…
37
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment