gaoqiong / composable_kernel_ROCM · Commits · 986182fc

Commit 986182fc, authored Sep 27, 2023 by Umang Yadav
Merge branch 'migraphx' into migx-jit-lib-hiprtc
Parents: 3ca84d92, 11cab2d5

Showing 10 changed files with 960 additions and 133 deletions (+960 −133)
.gitignore (+3 −0)
cmake/EnableCompilerWarnings.cmake (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp (+351 −19)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp (+3 −1)
include/ck/tensor_operation/gpu/device/masking_specialization.hpp (+1 −1)
library/src/jit_library/CMakeLists.txt (+1 −0)
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp (+109 −0)
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp (+114 −0)
library/src/jit_library/util/file_templates.py (+177 −0)
library/src/jit_library/util/make_instance_strings.py (+200 −111)
.gitignore  (view file @ 986182fc)

@@ -63,3 +63,6 @@ _templates/
 _toc.yml
 docBin/
 _doxygen/
+
+# pycache
+__pycache__/
cmake/EnableCompilerWarnings.cmake  (view file @ 986182fc)

@@ -65,8 +65,8 @@ else()
       -Wuninitialized
       -Wunreachable-code
       -Wunused
-      -Werror
       -Wno-reserved-identifier
+      -Werror
       -Wno-option-ignored
       -Wsign-compare
       -Wno-extra-semi-stmt
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp  (view file @ 986182fc)

@@ -611,6 +611,95 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return true;
    }

    static constexpr bool
    IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
    {
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;

        // check vector load of A
        if constexpr(is_same_v<ALayout, Row>)
        {
            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<ALayout, Col>)
        {
            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of B
        if constexpr(is_same_v<BLayout, Row>)
        {
            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<BLayout, Col>)
        {
            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of B1
        if constexpr(is_same_v<B1Layout, Row>)
        {
            if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<B1Layout, Col>)
        {
            if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of C
        if constexpr(is_same_v<CLayout, Row>)
        {
            if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<CLayout, Col>)
        {
            if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
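The added IsSupported reduces, per tensor, to one divisibility test: the extent of whichever dimension is contiguous in memory must be a multiple of the configured scalar-per-vector width, or a vectorized access would straddle a row boundary. A minimal standalone sketch of the rule for A, in plain C++ with a hypothetical vector width (not the CK API):

    #include <cstdio>

    // Hypothetical tile configuration: each thread moves 8 contiguous elements.
    constexpr int kScalarPerVector = 8;

    // Row-major A of shape M x K keeps K contiguous, so K must divide evenly
    // by the vector width; column-major A makes M the contiguous extent.
    bool a_load_supported(int M, int K, bool row_major)
    {
        const int extent = row_major ? K : M;
        return extent % kScalarPerVector == 0;
    }

    int main()
    {
        std::printf("%d\n", a_load_supported(128, 60, true)); // 0: 60 % 8 != 0
        std::printf("%d\n", a_load_supported(128, 64, true)); // 1: 64 % 8 == 0
    }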
@@ -625,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
         const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];

-        // Check scalar per vector requirement
-        const auto a_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
-        const auto b_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
-        const auto b1_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
-        const auto c_extent_lowest =
-            is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
-
-        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
-             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
-             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
-             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
-        {
-            return false;
-        }
-
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,
                                            arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+                                           arg.block_2_ctile_map_) and
+               IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
     }

     // polymorphic
@@ -766,6 +838,266 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return str.str();
    }

    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    struct Descriptor
    {
        template <class AGridDescriptor>
        static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
        {
            const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);

            const auto M = a_grid_desc_m_k.GetLength(I0);
            const auto K = a_grid_desc_m_k.GetLength(I1);

            const auto AK0 = K / AK1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
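The transform reshapes the padded M x K view into a K0 x M x K1 view with K = AK0 * AK1, so AK1-sized slices stay contiguous for vector loads. A small index-arithmetic sketch of that unmerge, with assumed values K = 64 and AK1 = 8:

    #include <cassert>

    int main()
    {
        // Assume K = 64 and AK1 = 8; the hpp computes AK0 = K / AK1.
        constexpr int K = 64, AK1 = 8, AK0 = K / AK1;

        // Flat index k maps to (ak0, ak1) = (k / AK1, k % AK1), so element
        // (m, k) of the M x K view becomes element (ak0, m, ak1) afterwards.
        for(int k = 0; k < K; ++k)
        {
            const int ak0 = k / AK1;
            const int ak1 = k % AK1;
            assert(ak0 * AK1 + ak1 == k); // the unmerge is a bijection on k
        }
        assert(AK0 == 8);
    }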
        template <class BGridDescriptor>
        static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
        {
            const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);

            const auto N = b_grid_desc_n_k.GetLength(I0);
            const auto K = b_grid_desc_n_k.GetLength(I1);

            const auto BK0 = K / BK1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class B1GridDescriptor>
        static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
        {
            const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);

            const auto N = b1_grid_desc_n_k.GetLength(I0);
            const auto K = b1_grid_desc_n_k.GetLength(I1);

            const auto B1K0 = K / B1K1;

            return transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class CGridDescriptor>
        static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
        {
            return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
        }

        using AGridDesc_AK0_M_AK1  = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
        using BGridDesc_BK0_N_BK1  = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
        using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
        using CGridDesc_M_N        = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
        // GridwiseGemm
        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
            ADataType, // TODO: distinguish A/B datatype
            GemmAccDataType,
            CShuffleDataType,
            CDataType,
            AElementwiseOperation,
            BElementwiseOperation,
            AccElementwiseOperation,
            B1ElementwiseOperation,
            CElementwiseOperation,
            InMemoryDataOperationEnum::Set,
            AGridDesc_AK0_M_AK1,
            BGridDesc_BK0_N_BK1,
            B1GridDesc_BK0_N_BK1,
            CGridDesc_M_N,
            NumGemmKPrefetchStage,
            BlockSize,
            MPerBlock,
            NPerBlock,
            KPerBlock,
            Gemm1NPerBlock,
            Gemm1KPerBlock,
            AK1,
            BK1,
            B1K1,
            MPerXDL,
            NPerXDL,
            MXdlPerWave,
            NXdlPerWave,
            Gemm1NXdlPerWave,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            true,
            ABlockLdsExtraM,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder,
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            true,
            BBlockLdsExtraN,
            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
            B1BlockTransferThreadClusterArrangeOrder,
            B1BlockTransferSrcAccessOrder,
            B1BlockTransferSrcVectorDim,
            B1BlockTransferSrcScalarPerVector,
            B1BlockTransferDstScalarPerVector_BK1,
            false,
            B1BlockLdsExtraN,
            CShuffleMXdlPerWavePerShuffle,
            CShuffleNXdlPerWavePerShuffle,
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            CShuffleBlockTransferScalarPerVector_NPerBlock,
            LoopSched,
            matrix_padder.PadN,
            MaskOutUpperTriangle>;

        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
        CGridDesc_M_N c_grid_desc_m_n;
        C0MatrixMask c0_matrix_mask;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_descriptor_mblock_mperblock_nblock_nperblock;

        // element-wise op
        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;
        B1ElementwiseOperation b1_element_op;
        CElementwiseOperation c_element_op;

        bool has_main_k_block_loop = true;
        bool is_valid              = false;

        constexpr Descriptor(ADesc a,
                             BDesc b,
                             B1Desc b1,
                             CDesc c,
                             AElementwiseOperation a_element_op_,
                             BElementwiseOperation b_element_op_,
                             B1ElementwiseOperation b1_element_op_,
                             CElementwiseOperation c_element_op_)
            : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      c_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              c0_matrix_mask{c.GetLength(I1)},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
              is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_m_n,
                                                   block_2_ctile_map) and
                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
                                   a_grid_desc_ak0_m_ak1.GetLength(I0) *
                                       a_grid_desc_ak0_m_ak1.GetLength(I2),
                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
        {
        }

        constexpr bool IsValid() const { return is_valid; }
    };
    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    static constexpr auto
    make_descriptor(ADesc a,
                    BDesc b,
                    B1Desc b1,
                    CDesc c,
                    AElementwiseOperation a_element_op   = AElementwiseOperation{},
                    BElementwiseOperation b_element_op   = BElementwiseOperation{},
                    B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                    CElementwiseOperation c_element_op   = CElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
            a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
    }
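Descriptor evaluates CheckValidity and IsSupported exactly once, in the constructor's initializer list, and caches the result in is_valid; device code then only queries IsValid() or asserts on the flag. A self-contained analog of that construct-then-query pattern:

    #include <cassert>

    struct Descriptor
    {
        int m, n;
        bool is_valid;

        // Validity is computed once at construction, mirroring how the CK
        // Descriptor runs its checks in the initializer list.
        constexpr Descriptor(int m_, int n_) : m(m_), n(n_), is_valid(m_ > 0 && n_ > 0) {}

        constexpr bool IsValid() const { return is_valid; }
    };

    int main()
    {
        constexpr Descriptor d{128, 256};
        static_assert(d.IsValid(), "checked at compile time, free at run time");
        assert(d.IsValid());
    }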
    template <class Desc>
    __device__ static void Run(const Desc& desc,
                               const float scale,
                               const ADataType* __restrict__ p_a_grid,
                               const ADataType* __restrict__ p_b_grid,
                               const ADataType* __restrict__ p_b1_grid,
                               CDataType* __restrict__ p_c_grid)
    {
        assert(desc.is_valid);
        __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
        AccElementwiseOperation acc_element_op{scale};

        if(desc.has_main_k_block_loop)
        {
            Desc::GridwiseGemm::template Run<true>(
                p_a_grid,
                p_b_grid,
                p_b1_grid,
                p_c_grid,
                p_shared_block,
                desc.a_element_op,
                desc.b_element_op,
                acc_element_op,
                desc.b1_element_op,
                desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1,
                desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map,
                desc.c0_matrix_mask);
        }
        else
        {
            Desc::GridwiseGemm::template Run<false>(
                p_a_grid,
                p_b_grid,
                p_b1_grid,
                p_c_grid,
                p_shared_block,
                desc.a_element_op,
                desc.b_element_op,
                acc_element_op,
                desc.b1_element_op,
                desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1,
                desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map,
                desc.c0_matrix_mask);
        }
    }
};

} // namespace device
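Run branches once at launch on desc.has_main_k_block_loop and forwards to GridwiseGemm::Run<true> or Run<false>, so each specialization compiles without the other path's loop code. A minimal sketch of that template-bool dispatch idiom:

    #include <cstdio>

    // Stand-in for GridwiseGemm::Run<HasMainKBlockLoop>: the flag is a template
    // parameter, so each instantiation omits the dead branch entirely.
    template <bool HasMainKBlockLoop>
    void run_gemm()
    {
        if constexpr(HasMainKBlockLoop)
            std::printf("main-loop kernel\n");
        else
            std::printf("tail-only kernel\n");
    }

    void dispatch(bool has_main_k_block_loop)
    {
        // One runtime branch selects the specialization, as in Descriptor::Run.
        if(has_main_k_block_loop)
            run_gemm<true>();
        else
            run_gemm<false>();
    }

    int main() { dispatch(true); dispatch(false); }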
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp  (view file @ 986182fc)

@@ -581,7 +581,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
 #ifndef __HIPCC_RTC__
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!ck::is_xdl_supported())
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
         {
             return false;
         }
include/ck/tensor_operation/gpu/device/masking_specialization.hpp  (view file @ 986182fc)

@@ -53,7 +53,7 @@ struct MaskOutUpperTrianglePredicate
 template <typename MaskOutPredicate>
 struct C0MatrixMask_impl
 {
-    C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
+    constexpr C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}

     __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
     {
library/src/jit_library/CMakeLists.txt  (view file @ 986182fc)

@@ -13,6 +13,7 @@ execute_process(
 )

 add_library(jit_library STATIC
+    src/device_batched_gemm_softmax_gemm.cpp
     src/device_gemm_multiple_d.cpp
     src/common.cpp
 )
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp  (new file, 0 → 100644, view file @ 986182fc)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>

#include "ck/host/common.hpp"

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

struct Problem
{
    std::size_t M = 0;
    std::size_t N = 0;
    std::size_t K = 0;
    std::size_t O = 0;
    bool TransA   = false;
    bool TransB   = false;
    bool TransB1  = false;
    bool TransC   = false;
    DataType ADataType  = DataType::Half;
    DataType BDataType  = DataType::Half;
    DataType B1DataType = DataType::Half;
    DataType CDataType  = DataType::Half;
    std::string AElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string BElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string B1ElementOp  = "ck::tensor_operation::element_wise::PassThrough";
    std::string CElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string AccElementOp = "ck::tensor_operation::element_wise::Scale";

    std::string GetIncludeHeader() const;

    std::vector<Solution> GetSolutions(const std::string& arch) const;

    private:
    std::vector<std::string> GetInstances(const std::string& arch) const;

    Solution MakeSolution(std::size_t idx, const std::string& arch) const;

    // Positions of the template parameters inside a whitespace-separated
    // instance string (consumed positionally by MakeSolution in the .cpp).
    static const std::size_t DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx = 0;
    static const std::size_t ALayout_idx                                  = 1;
    static const std::size_t B0Layout_idx                                 = 2;
    static const std::size_t B1Layout_idx                                 = 3;
    static const std::size_t CLayout_idx                                  = 4;
    static const std::size_t ADataType_idx                                = 5;
    static const std::size_t B0DataType_idx                               = 6;
    static const std::size_t B1DataType_idx                               = 7;
    static const std::size_t CDataType_idx                                = 8;
    static const std::size_t AccDataType_idx                              = 9;
    static const std::size_t CShuffleDataType_idx                         = 10;
    static const std::size_t AElementwiseOperation_idx                    = 11;
    static const std::size_t B0ElementwiseOperation_idx                   = 12;
    static const std::size_t Acc0ElementwiseOperation_idx                 = 13;
    static const std::size_t B1ElementwiseOperation_idx                   = 14;
    static const std::size_t CElementwiseOperation_idx                    = 15;
    static const std::size_t GEMMSpecialization_idx                       = 16;
    static const std::size_t NumGemmKPrefetchStage_idx                    = 17;
    static const std::size_t BlockSize_idx                                = 18;
    static const std::size_t Gemm01MPerBlock_idx                          = 19;
    static const std::size_t Gemm0NPerBlock_idx                           = 20;
    static const std::size_t Gemm0KPerBlock_idx                           = 21;
    static const std::size_t Gemm1NPerBlock_idx                           = 22;
    static const std::size_t Gemm1KPerBlock_idx                           = 23;
    static const std::size_t AK1_idx                                      = 24;
    static const std::size_t BK1_idx                                      = 25;
    static const std::size_t B1K1_idx                                     = 26;
    static const std::size_t MPerXDL_idx                                  = 27;
    static const std::size_t NPerXDL_idx                                  = 28;
    static const std::size_t Gemm0MXdlPerWave_idx                         = 29;
    static const std::size_t Gemm0NXdlPerWave_idx                         = 30;
    static const std::size_t Gemm1NXdlPerWave_idx                         = 31;
    static const std::size_t ABlockTransferThreadClusterLengths_K0_M_K1_idx  = 32;
    static const std::size_t ABlockTransferThreadClusterArrangeOrder_idx     = 33;
    static const std::size_t ABlockTransferSrcAccessOrder_idx                = 34;
    static const std::size_t ABlockTransferSrcVectorDim_idx                  = 35;
    static const std::size_t ABlockTransferSrcScalarPerVector_idx            = 36;
    static const std::size_t ABlockTransferDstScalarPerVector_K1_idx         = 37;
    static const std::size_t ABlockLdsAddExtraM_idx                          = 38;
    static const std::size_t B0BlockTransferThreadClusterLengths_K0_N_K1_idx = 39;
    static const std::size_t B0BlockTransferThreadClusterArrangeOrder_idx    = 40;
    static const std::size_t B0BlockTransferSrcAccessOrder_idx               = 41;
    static const std::size_t B0BlockTransferSrcVectorDim_idx                 = 42;
    static const std::size_t B0BlockTransferSrcScalarPerVector_idx           = 43;
    static const std::size_t B0BlockTransferDstScalarPerVector_K1_idx        = 44;
    static const std::size_t B0BlockLdsAddExtraN_idx                         = 45;
    static const std::size_t B1BlockTransferThreadClusterLengths_K0_N_K1_idx = 46;
    static const std::size_t B1BlockTransferThreadClusterArrangeOrder_idx    = 47;
    static const std::size_t B1BlockTransferSrcAccessOrder_idx               = 48;
    static const std::size_t B1BlockTransferSrcVectorDim_idx                 = 49;
    static const std::size_t B1BlockTransferSrcScalarPerVector_idx           = 50;
    static const std::size_t B1BlockTransferDstScalarPerVector_K1_idx        = 51;
    static const std::size_t B1BlockLdsAddExtraN_idx                         = 52;
    static const std::size_t CShuffleMXdlPerWavePerShuffle_idx               = 53;
    static const std::size_t CShuffleNXdlPerWavePerShuffle_idx               = 54;
    static const std::size_t CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
    static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx  = 56;
    static const std::size_t MaskOutUpperTriangle_idx                        = 57;
};

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
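Problem carries the elementwise operations as fully qualified type names in strings rather than as C++ types, because the JIT layer splices them textually into instance templates before runtime compilation. A reduced, self-contained analog of that idea (not the real Problem, which also carries data types, transposes, and the index table above):

    #include <cstddef>
    #include <cstdio>
    #include <string>

    // Reduced analog: string-typed ops get substituted into generated C++
    // instance source at JIT time instead of being linked as symbols.
    struct Problem
    {
        std::size_t M = 0, N = 0, K = 0, O = 0;
        std::string AccElementOp = "ck::tensor_operation::element_wise::Scale";
    };

    int main()
    {
        Problem p{/*M*/ 256, /*N*/ 256, /*K*/ 64, /*O*/ 64};
        std::printf("acc op spliced into the template: %s\n", p.AccElementOp.c_str());
    }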
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp  (new file, 0 → 100644, view file @ 986182fc)

#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/common.hpp"
#include "batched_gemm_softmax_gemm_instances.hpp"
#include <algorithm>
#include <unordered_set>

namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {

std::string GetGemmSpec(const std::size_t m,
                        const std::size_t n,
                        const std::size_t k,
                        const std::size_t n1,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block,
                        const std::size_t k_per_block,
                        const std::size_t n1_per_block)
{
    std::string spec = "";
    if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
        spec += "M";
    if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
        spec += "N";
    if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
        spec += "K";
    if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0)
        spec += "O";

    if(spec == "")
        return "ck::tensor_operation::device::GemmSpecialization::Default";

    return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
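GetGemmSpec appends one letter per problem dimension that is not an exact multiple of its tile size, yielding specializations such as MNKOPadding or Default. A compile-and-run check of the same rule with assumed tile sizes (128/128/32/64 here; the real values come from the instance string):

    #include <cstdio>
    #include <string>

    std::size_t integer_divide_ceil(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

    // Same rule as GetGemmSpec: pad a dimension iff rounding it up to the
    // tile size changes its value.
    std::string gemm_spec(std::size_t m, std::size_t n, std::size_t k, std::size_t o)
    {
        std::string spec;
        if(integer_divide_ceil(m, 128) * 128 != m) spec += "M";
        if(integer_divide_ceil(n, 128) * 128 != n) spec += "N";
        if(integer_divide_ceil(k, 32) * 32 != k)   spec += "K";
        if(integer_divide_ceil(o, 64) * 64 != o)   spec += "O";
        return spec.empty() ? "Default" : spec + "Padding";
    }

    int main()
    {
        std::printf("%s\n", gemm_spec(1000, 256, 64, 64).c_str()); // "MPadding"
        std::printf("%s\n", gemm_spec(1024, 256, 64, 64).c_str()); // "Default"
    }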
std::size_t GetGridSize(const std::size_t m,
                        const std::size_t n,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block)
{
    return integer_divide_ceil(m, m_per_block) * integer_divide_ceil(n, n_per_block);
}

const std::unordered_set<std::string>& get_xdlop_archs()
{
    static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx940"};
    return supported_archs;
}

std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
    std::vector<std::string> instances;
    if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
    {
        ck::host::instance::batched_gemm_softmax_gemm_instances all_instances{};
        instances = all_instances.get_instances();
    }
    return instances;
}

Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
    auto template_str = GetInstances(arch).at(idx);
    std::istringstream iss(template_str);
    std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                    std::istream_iterator<std::string>{});

    params[AElementwiseOperation_idx]    = AElementOp;
    params[B0ElementwiseOperation_idx]   = BElementOp;
    params[B1ElementwiseOperation_idx]   = BElementOp;
    params[CElementwiseOperation_idx]    = CElementOp;
    params[Acc0ElementwiseOperation_idx] = AccElementOp;

    auto block_size_str   = params[BlockSize_idx];
    auto m_per_block_str  = params[Gemm01MPerBlock_idx];
    auto n_per_block_str  = params[Gemm0NPerBlock_idx];
    auto k_per_block_str  = params[Gemm0KPerBlock_idx];
    auto n1_per_block_str = params[Gemm1NPerBlock_idx];

    const std::size_t block_size   = std::stoi(block_size_str);
    const std::size_t m_per_block  = std::stoi(m_per_block_str);
    const std::size_t n_per_block  = std::stoi(n_per_block_str);
    const std::size_t k_per_block  = std::stoi(k_per_block_str);
    const std::size_t n1_per_block = std::stoi(n1_per_block_str);

    const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block);

    params[GEMMSpecialization_idx] =
        GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);

    std::string str = std::accumulate(params.begin() + 1,
                                      params.end(),
                                      std::string{},
                                      [](const std::string& a, const std::string& b) {
                                          return a.empty() ? b : a + ", " + b;
                                      });
    str = params.front() + "< " + str + ">";
    return Solution{str, block_size, grid_size};
}
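MakeSolution treats an instance template as a whitespace-separated token list: istream_iterator splits it, the *_idx constants from the header patch fields positionally, and accumulate re-joins everything after the op name with commas. A minimal reproduction of that pipeline with a hypothetical four-token template:

    #include <iostream>
    #include <iterator>
    #include <numeric>
    #include <sstream>
    #include <string>
    #include <vector>

    int main()
    {
        // Hypothetical instance string: op name followed by positional fields.
        std::string template_str = "DeviceOp Row F16 PassThrough";

        std::istringstream iss(template_str);
        std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                        std::istream_iterator<std::string>{});

        params[3] = "ck::tensor_operation::element_wise::Scale"; // patch by index

        // Re-join everything after the name with ", ", as MakeSolution does.
        std::string str = std::accumulate(
            params.begin() + 1, params.end(), std::string{},
            [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });

        std::cout << params.front() + "< " + str + ">" << "\n";
        // DeviceOp< Row, F16, ck::tensor_operation::element_wise::Scale>
    }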
std::string Problem::GetIncludeHeader() const
{
    return ck::host::instance::batched_gemm_softmax_gemm_instances{}.get_include_header();
}

std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
    std::vector<Solution> solutions;
    const std::size_t num_instances = GetInstances(arch).size();
    for(std::size_t i = 0; i < num_instances; ++i)
    {
        solutions.push_back(MakeSolution(i, arch));
    }
    return solutions;
}

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
library/src/jit_library/util/file_templates.py  (new file, 0 → 100644, view file @ 986182fc)

out_file_with_quant = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {col_row_name} =
{{
{col_row_instances}
}};
static inline std::vector<std::string> {col_col_name} =
{{
{col_col_instances}
}};
static inline std::vector<std::string> {row_row_name} =
{{
{row_row_instances}
}};
static inline std::vector<std::string> {row_col_name} =
{{
{row_col_instances}
}};
static inline std::vector<std::string> {int8_col_row_name} =
{{
{int8_col_row_instances}
}};
static inline std::vector<std::string> {int8_col_col_name} =
{{
{int8_col_col_instances}
}};
static inline std::vector<std::string> {int8_row_row_name} =
{{
{int8_row_row_instances}
}};
static inline std::vector<std::string> {int8_row_col_name} =
{{
{int8_row_col_instances}
}};
static auto get_col_row_instances(const bool quantize)
{{
return quantize ? {int8_col_row_name} :
{col_row_name};
}}
static auto get_col_col_instances(const bool quantize)
{{
return quantize ? {int8_col_col_name} :
{col_col_name};
}}
static auto get_row_row_instances(const bool quantize)
{{
return quantize ? {int8_row_row_name} :
{row_row_name};
}}
static auto get_row_col_instances(const bool quantize)
{{
return quantize ? {int8_row_col_name} :
{row_col_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
out_file_no_quant = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {instances_name} =
{{
{instances}
}};
static auto get_instances()
{{
return {instances_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
def get_device_gemm_multiple_d_file(op_name,
                                    col_row_name, col_row_instances,
                                    col_col_name, col_col_instances,
                                    row_row_name, row_row_instances,
                                    row_col_name, row_col_instances,
                                    int8_col_row_name, int8_col_row_instances,
                                    int8_col_col_name, int8_col_col_instances,
                                    int8_row_row_name, int8_row_row_instances,
                                    int8_row_col_name, int8_row_col_instances,
                                    include_header):
    return out_file_with_quant.format(op_name=op_name,
                                      col_row_name=col_row_name,
                                      col_row_instances=col_row_instances,
                                      col_col_name=col_col_name,
                                      col_col_instances=col_col_instances,
                                      row_row_name=row_row_name,
                                      row_row_instances=row_row_instances,
                                      row_col_name=row_col_name,
                                      row_col_instances=row_col_instances,
                                      int8_col_row_name=int8_col_row_name,
                                      int8_col_row_instances=int8_col_row_instances,
                                      int8_col_col_name=int8_col_col_name,
                                      int8_col_col_instances=int8_col_col_instances,
                                      int8_row_row_name=int8_row_row_name,
                                      int8_row_row_instances=int8_row_row_instances,
                                      int8_row_col_name=int8_row_col_name,
                                      int8_row_col_instances=int8_row_col_instances,
                                      include_header=include_header)


def get_device_gemm_softmax_gemm_file(op_name, instances_name, instances, include_header):
    return out_file_no_quant.format(op_name=op_name,
                                    instances_name=instances_name,
                                    instances=instances,
                                    include_header=include_header)
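For orientation, this is roughly what get_device_gemm_softmax_gemm_file renders out of the out_file_no_quant template for a hypothetical op with a single instance string (all names below are invented placeholders, not real generated output):

    // SPDX-License-Identifier: MIT
    // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
    #pragma once
    #include <cstdlib>
    #include <vector>
    #include <memory>

    namespace ck {
    namespace host {
    namespace instance {

    struct my_op_instances // filled in from {op_name}
    {
        static inline std::vector<std::string> my_instance_list = // {instances_name}
        {
            "ck::tensor_operation::device::DeviceOp< ... >" // one entry per {instances} line
        };
        static auto get_instances() { return my_instance_list; }
        static auto get_include_header() { return "path/to/impl.hpp"; } // {include_header}
    };

    } // namespace instance
    } // namespace host
    } // namespace ck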
library/src/jit_library/util/make_instance_strings.py  (view file @ 986182fc)

-import argparse, re, json, os, sys
+import argparse, re, json, os, sys, file_templates

-out_file = """// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
...
-"""

 def strip_sequences(str):
-    matches = re.findall(r'S<\d+(?:,\s*\d+)*>', str)
+    matches = re.findall(r'S<\s*\d+(?:,\s*\d+)*>', str)
     for match in matches:
         str = str.replace(match, match.replace(' ', ''))
     str = str.replace('S<', "ck::Sequence<")
@@ -251,27 +161,206 @@ def parse_instances(source, out_dir):
     int8_file = "/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
     int8_instances = get_int8_instances(source, int8_file, "DeviceGemmMultipleD_Xdl_CShuffle")

     with open(os.path.join(out_dir, out_file_name), "w+") as f:
-        f.write(out_file.format(op_name=op_name,
-                                col_row_name=col_row_name,
-                                col_row_instances="\n".join(col_row_instances),
-                                col_col_name=col_col_name,
-                                col_col_instances="\n".join(col_col_instances),
-                                row_row_name=row_row_name,
-                                row_row_instances="\n".join(row_row_instances),
-                                row_col_name=row_col_name,
-                                row_col_instances="\n".join(row_col_instances),
-                                int8_col_row_name=int8_instances["col_row_name"],
-                                int8_col_row_instances="\n".join(int8_instances["col_row"]),
-                                int8_col_col_name=int8_instances["col_col_name"],
-                                int8_col_col_instances="\n".join(int8_instances["col_col"]),
-                                int8_row_row_name=int8_instances["row_row_name"],
-                                int8_row_row_instances="\n".join(int8_instances["row_row"]),
-                                int8_row_col_name=int8_instances["row_col_name"],
-                                int8_row_col_instances="\n".join(int8_instances["row_col"]),
-                                include_header=include_header))
+        f.write(file_templates.get_device_gemm_multiple_d_file(
+            op_name,
+            col_row_name, "\n".join(col_row_instances),
+            col_col_name, "\n".join(col_col_instances),
+            row_row_name, "\n".join(row_row_instances),
+            row_col_name, "\n".join(row_col_instances),
+            int8_instances["col_row_name"], "\n".join(int8_instances["col_row"]),
+            int8_instances["col_col_name"], "\n".join(int8_instances["col_col"]),
+            int8_instances["row_row_name"], "\n".join(int8_instances["row_row"]),
+            int8_instances["row_col_name"], "\n".join(int8_instances["row_col"]),
+            include_header))
def parse_device_gemm_multiple_d_instances(source, out_dir):
    aliases = {"F16_F16_Tuple": "ck::Tuple<F16,F16>",
               "Row_Row_Tuple": "ck::Tuple<Row,Row>",
               "Empty_Tuple": "ck::Tuple<>",
               "LoopScheduler": "ck::LoopScheduler",
               "PipelineVersion": "ck::PipelineVersion",
               "Row": "ck::tensor_layout::gemm::RowMajor",
               "Col": "ck::tensor_layout::gemm::ColumnMajor",
               "F16": "ck::half_t",
               "F32": "float",
               "OutElementOp": "PassThrough"}
    device_ops = {"gemm_add_add_fastgelu": "DeviceGemmMultipleD_Xdl_CShuffle",
                  #"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
                  }
    for root_, dirs_, files_ in os.walk(source):
        for dir in dirs_:
            op_name = os.path.split(dir)[-1]
            if op_name not in device_ops:
                continue
            col_row_name = ""
            col_col_name = ""
            row_row_name = ""
            row_col_name = ""
            row_row_instances = []
            col_row_instances = []
            row_col_instances = []
            col_col_instances = []
            for root, dirs, files in os.walk(os.path.join(root_, dir)):
                for file in files:
                    if not file.endswith(".cpp"):
                        continue
                    file_name = os.path.split(file)[-1]
                    is_row_row = bool(re.search(".*mk.*kn.*", file_name))
                    is_col_row = bool(re.search(".*km.*kn.*", file_name))
                    is_row_col = bool(re.search(".*mk.*nk.*", file_name))
                    is_col_col = bool(re.search(".*km.*nk.*", file_name))
                    if is_row_row:
                        row_row_name = file_name[:-4]
                    if is_col_row:
                        col_row_name = file_name[:-4]
                    if is_row_col:
                        row_col_name = file_name[:-4]
                    if is_col_col:
                        col_col_name = file_name[:-4]
                    instances_list = []
                    template_name = device_ops[op_name]
                    include_header = ""
                    with open(os.path.join(root, file)) as f:
                        for line in f:
                            if "impl" in line:
                                include_header = line.replace('#include "', "").replace('"', "").replace("\n", "")
                            elif template_name in line:
                                # Turn all whitespace into single spaces
                                new_line = " ".join(line.split())
                                # Remove whitespace from S<*>
                                new_line = strip_sequences(new_line)
                                new_line = remove_commas_and_brackets(new_line)
                                last_char = "\n"
                                if new_line[-1] == ",":
                                    last_char = ",\n"
                                    new_line = new_line[:-1]
                                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                                for key in aliases:
                                    new_line = new_line.replace(key, aliases[key])
                                instances_list.append(new_line)
                    instances_list[-1] = instances_list[-1][:-1]
                    if is_row_row:
                        row_row_instances = instances_list
                    if is_col_row:
                        col_row_instances = instances_list
                    if is_row_col:
                        row_col_instances = instances_list
                    if is_col_col:
                        col_col_instances = instances_list
            out_file_name = op_name + "_instances.hpp"
            if not os.path.exists(out_dir):
                os.mkdir(out_dir)
            int8_file = "/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
            int8_instances = get_int8_instances(source, int8_file, "DeviceGemmMultipleD_Xdl_CShuffle")
            with open(os.path.join(out_dir, out_file_name), "w+") as f:
                f.write(file_templates.get_device_gemm_multiple_d_file(
                    op_name,
                    col_row_name, "\n".join(col_row_instances),
                    col_col_name, "\n".join(col_col_instances),
                    row_row_name, "\n".join(row_row_instances),
                    row_col_name, "\n".join(row_col_instances),
                    int8_instances["col_row_name"], "\n".join(int8_instances["col_row"]),
                    int8_instances["col_col_name"], "\n".join(int8_instances["col_col"]),
                    int8_instances["row_row_name"], "\n".join(int8_instances["row_row"]),
                    int8_instances["row_col_name"], "\n".join(int8_instances["row_col"]),
                    include_header))
def parse_param_names(file):
    param_names = []
    for line in file:
        if bool(re.search(r"\s*//#+", line)):
            names = line.split('|')
            names = [n.strip() for n in names]
            if not param_names:
                param_names = [""] * len(names)
            param_names = [a + b for a, b in zip(param_names, names)]
        elif param_names:
            param_names[0] = line.split('<')[0].strip()
            file.seek(0)
            return param_names[:-1]
    file.seek(0)
    return param_names[:-1]
def parse_device_batched_gemm_softmax_gemm_instances(source, out_dir):
    aliases = {"Row": "ck::tensor_layout::gemm::RowMajor",
               "Col": "ck::tensor_layout::gemm::ColumnMajor",
               "F16": "ck::half_t",
               "F32": "float",
               "PassThrough": "ck::tensor_operation::element_wise::PassThrough",
               "Scale": "ck::tensor_operation::element_wise::Scale",
               "GemmPadded": "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
               "GemmDefault": "ck::tensor_operation::device::GemmSpecialization::Default"}
    device_ops = {"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"}
    for root_, dirs_, files_ in os.walk(source):
        for dir in dirs_:
            op_name = os.path.split(dir)[-1]
            if "permute" in op_name or op_name not in device_ops:
                continue
            for root, dirs, files in os.walk(os.path.join(root_, dir)):
                for file in files:
                    if not file.endswith(".cpp"):
                        continue
                    file_name = os.path.split(file)[-1]
                    instances_name = file_name[:-4]
                    instances_list = []
                    template_name = device_ops[op_name]
                    include_header = ""
                    with open(os.path.join(root, file)) as f:
                        param_names = parse_param_names(f)
                        # for i in range(len(param_names)):
                        #     print(f"{i}: {param_names[i]}")
                        for line in f:
                            if "impl" in line:
                                include_header = line.replace('#include "', "").replace('"', "").replace("\n", "")
                            elif template_name in line:
                                # Turn all whitespace into single spaces
                                new_line = " ".join(line.split())
                                # Remove whitespace from S<*>
                                new_line = strip_sequences(new_line)
                                new_line = remove_commas_and_brackets(new_line)
                                last_char = "\n"
                                if new_line[-1] == ",":
                                    last_char = ",\n"
                                    new_line = new_line[:-1]
                                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                                for key in aliases:
                                    new_line = new_line.replace(key, aliases[key])
                                masking = new_line.replace("Masking", "true")
                                no_masking = new_line.replace("Masking", "false")
                                instances_list.append(masking)
                                instances_list.append(no_masking)
                    out_file_name = op_name + "_instances.hpp"
                    if not os.path.exists(out_dir):
                        os.mkdir(out_dir)
                    with open(os.path.join(out_dir, out_file_name), "w+") as f:
                        f.write(file_templates.get_device_gemm_softmax_gemm_file(
                            op_name,
                            instances_name,
                            "\n".join(instances_list),
                            include_header))
def run(args):
    parse_instances(args[0], args[1])
    parse_device_gemm_multiple_d_instances(args[0], args[1])
    parse_device_batched_gemm_softmax_gemm_instances(args[0], args[1])


if __name__ == '__main__':
    run(sys.argv[1:])