Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d9676215
Commit
d9676215
authored
Sep 20, 2023
by
Alan Turner
Browse files
Add Descriptor and Run to device op
parent
611196d5
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
110 additions
and
55 deletions
+110
-55
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+96
-35
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
...ck/tensor_operation/gpu/device/masking_specialization.hpp
+1
-1
library/src/jit_library/CMakeLists.txt
library/src/jit_library/CMakeLists.txt
+1
-0
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp
...rary/include/ck/host/device_batched_gemm_softmax_gemm.hpp
+10
-10
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
.../src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
+2
-9
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
d9676215
...
@@ -662,7 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -662,7 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
auto
c_extent_lowest
=
const
auto
c_extent_lowest
=
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>
?
Gemm1NRaw
:
MRaw
;
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>
?
Gemm1NRaw
:
MRaw
;
if
constexpr
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
...
@@ -857,26 +857,83 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -857,26 +857,83 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
BDesc
{}))
>
;
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
BDesc
{}))
>
;
using
B1GridDesc_BK0_N_BK1
=
using
B1GridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
(
B1Desc
{}))
>
;
remove_cvref_t
<
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
(
B1Desc
{}))
>
;
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M_N
(
CDesc
{}))
>
;
MakeCGridDescriptor_M_N
(
CDesc
{})))
>
;
using
Block2CTileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
// GridwiseGemm
MakeCGridDescriptor_M_N
(
CDesc
{})))
>
;
using
GridwiseGemmSpec
=
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
using
C0MatrixMask
=
conditional_t
<
MaskOutUpperTriangle
,
ADataType
,
// TODO: distinguish A/B datatype
C0MatrixMask_impl
<
MaskOutUpperTrianglePredicate
>
,
GemmAccDataType
,
C0MatrixMask_impl
<
MaskDisabledPredicate
>>
;
CShuffleDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
B1GridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
true
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
false
,
B1BlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
matrix_padder
.
PadN
,
MaskOutUpperTriangle
>
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
;
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_descriptor_mblock_mperblock_nblock_nperblock
;
CGridDesc_M_N
c_grid_desc_m_n
;
Block2CTileMap
block_2_ctile_map
;
C0MatrixMask
c0_matrix_mask
;
C0MatrixMask
c0_matrix_mask
;
typename
GridwiseGemmSpec
::
DefaultBlock2CTileMap
block_2_ctile_map
;
typename
GridwiseGemmSpec
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_descriptor_mblock_mperblock_nblock_nperblock
;
// element-wise op
// element-wise op
AElementwiseOperation
a_element_op
;
AElementwiseOperation
a_element_op
;
BElementwiseOperation
b_element_op
;
BElementwiseOperation
b_element_op
;
AccElementwiseOperation
acc_element_op
;
B1ElementwiseOperation
b1_element_op
;
B1ElementwiseOperation
b1_element_op
;
CElementwiseOperation
c_element_op
;
CElementwiseOperation
c_element_op
;
...
@@ -889,31 +946,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -889,31 +946,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CDesc
c
,
CDesc
c
,
AElementwiseOperation
a_element_op_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
BElementwiseOperation
b_element_op_
,
AccElementwiseOperation
acc_element_op_
,
B1ElementwiseOperation
b1_element_op_
,
B1ElementwiseOperation
b1_element_op_
,
CElementwiseOperation
c_element_op_
)
CElementwiseOperation
c_element_op_
)
:
a_grid_desc_ak0_m_ak1
{
MakeAGridDescriptor_AK0_M_AK1
(
a
)},
:
a_grid_desc_ak0_m_ak1
{
MakeAGridDescriptor_AK0_M_AK1
(
a
)},
b_grid_desc_bk0_n_bk1
{
MakeBGridDescriptor_BK0_N_BK1
(
b
)},
b_grid_desc_bk0_n_bk1
{
MakeBGridDescriptor_BK0_N_BK1
(
b
)},
b1_grid_desc_bk0_n_bk1
{
MakeB1GridDescriptor_BK0_N_BK1
(
b1
)},
b1_grid_desc_bk0_n_bk1
{
MakeB1GridDescriptor_BK0_N_BK1
(
b1
)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock
{
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
{
MakeCGridDescriptor_M_N
(
c
)},
MakeCGridDescriptor_M_N
(
c
))},
block_2_ctile_map
{
GridwiseGemmSpec
::
MakeDefaultBlock2CTileMap
(
block_2_etile_map
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
)},
MakeCGridDescriptor_M_N
(
c
))},
c_grid_descriptor_mblock_mperblock_nblock_nperblock
{
has_main_k_block_loop
{
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
GridwiseGemmSpec
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
)},
has_main_k_block_loop
{
GridwiseGemmSpec
::
CalculateHasMainKBlockLoop
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))},
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))},
c0_matrix_mask
{
c
.
GetLength
(
I1
)}
c0_matrix_mask
{
c
.
GetLength
(
I1
)}
,
a_element_op
{
a_element_op_
},
a_element_op
{
a_element_op_
},
b_element_op
{
b_element_op_
},
b_element_op
{
b_element_op_
},
acc_element_op
{
acc_element_op_
},
b1_element_op
{
b1_element_op_
},
b1_element_op
{
b1_element_op_
},
c_element_op
{
c_element_op_
},
c_element_op
{
c_element_op_
},
is_valid
{
GridwiseGemm
::
CheckValidity
(
is_valid
{
GridwiseGemm
Spec
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
MakeCGridDescriptor_M_N
(
c
),
c_grid_desc_m_n
,
block_2_ctile_map
)
and
block_2_ctile_map
)}
IsSupported
(
c
.
GetLength
(
I0
),
c
.
GetLength
(
I1
),
a
.
GetLength
(
I1
),
b1
.
GetLength
(
I1
))}
{
{
}
}
...
@@ -927,37 +982,43 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -927,37 +982,43 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
static
constexpr
auto
static
constexpr
auto
make_descriptor
(
ADesc
a
,
make_descriptor
(
ADesc
a
,
BDesc
b
,
BDesc
b
,
B1Desc
b1
desc
,
B1Desc
b1
,
CDesc
c
,
CDesc
c
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
AccElementwiseOperation
acc_element_op
=
AccElementwiseOperation
{},
B1ElementwiseOperation
b1_element_op
=
B1ElementwiseOperation
{},
B1ElementwiseOperation
b1_element_op
=
B1ElementwiseOperation
{},
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{})
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{})
{
{
return
Descriptor
<
ADesc
,
BDesc
,
B1Desc
,
CDesc
>
(
return
Descriptor
<
ADesc
,
BDesc
,
B1Desc
,
CDesc
>
(
a
,
b
,
b1
,
c
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
);
a
,
b
,
b1
,
c
,
a_element_op
,
b_element_op
,
b1_element_op
,
c_element_op
);
}
}
template
<
class
Desc
>
template
<
class
Desc
>
__device__
static
void
Run
(
const
Desc
&
desc
,
__device__
static
void
Run
(
const
Desc
&
desc
,
const
float
scale
,
const
ADataType
*
__restrict__
p_a_grid
,
const
ADataType
*
__restrict__
p_a_grid
,
const
ADataType
*
__restrict__
p_b_grid
,
const
ADataType
*
__restrict__
p_b_grid
,
const
ADataType
*
__restrict__
p_b1_grid
,
const
ADataType
*
__restrict__
p_b1_grid
,
CDataType
*
__restrict__
p_c_grid
)
CDataType
*
__restrict__
p_c_grid
)
{
{
__shared__
char
p_shared_block
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
assert
(
desc
.
is_valid
and
assert
(
desc
.
is_valid
);
IsSupported
(
desc
.
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
),
desc
.
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
),
desc
.
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
desc
.
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
),
desc
.
b1_grid_desc_bk0_n_bk1
.
GetLength
(
I1
)));
__shared__
char
p_shared_block
[
Desc
::
GridwiseGemmSpec
::
GetSharedMemoryNumberOfByte
()];
AccElementwiseOperation
acc_element_op
{
scale
};
if
(
desc
.
has_main_k_block_loop
)
if
(
desc
.
has_main_k_block_loop
)
{
{
GridwiseGemm
::
template
Run
<
true
>(
p_a_grid
,
Desc
::
GridwiseGemm
Spec
::
template
Run
<
true
>(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_b1_grid
,
p_b1_grid
,
p_c_grid
,
p_c_grid
,
p_shared
,
p_shared
_block
,
desc
.
a_element_op
,
desc
.
a_element_op
,
desc
.
b_element_op
,
desc
.
b_element_op
,
desc
.
acc_element_op
,
acc_element_op
,
desc
.
b1_element_op
,
desc
.
b1_element_op
,
desc
.
c_element_op
,
desc
.
c_element_op
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
a_grid_desc_ak0_m_ak1
,
...
@@ -969,14 +1030,14 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -969,14 +1030,14 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
}
else
else
{
{
GridwiseGemm
::
template
Run
<
false
>(
p_a_grid
,
Desc
::
GridwiseGemm
Spec
::
template
Run
<
false
>(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_b1_grid
,
p_b1_grid
,
p_c_grid
,
p_c_grid
,
p_shared
,
p_shared
_block
,
desc
.
a_element_op
,
desc
.
a_element_op
,
desc
.
b_element_op
,
desc
.
b_element_op
,
desc
.
acc_element_op
,
acc_element_op
,
desc
.
b1_element_op
,
desc
.
b1_element_op
,
desc
.
c_element_op
,
desc
.
c_element_op
,
desc
.
a_grid_desc_ak0_m_ak1
,
desc
.
a_grid_desc_ak0_m_ak1
,
...
...
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
View file @
d9676215
...
@@ -53,7 +53,7 @@ struct MaskOutUpperTrianglePredicate
...
@@ -53,7 +53,7 @@ struct MaskOutUpperTrianglePredicate
template
<
typename
MaskOutPredicate
>
template
<
typename
MaskOutPredicate
>
struct
C0MatrixMask_impl
struct
C0MatrixMask_impl
{
{
C0MatrixMask_impl
(
index_t
NRaw
)
:
NRaw_
(
NRaw
),
predicate_
(
MaskOutPredicate
{})
{}
constexpr
C0MatrixMask_impl
(
index_t
NRaw
)
:
NRaw_
(
NRaw
),
predicate_
(
MaskOutPredicate
{})
{}
__host__
__device__
constexpr
bool
IsNOutOfBound
(
/*index_t m, */
index_t
n
)
const
__host__
__device__
constexpr
bool
IsNOutOfBound
(
/*index_t m, */
index_t
n
)
const
{
{
...
...
library/src/jit_library/CMakeLists.txt
View file @
d9676215
...
@@ -13,6 +13,7 @@ execute_process(
...
@@ -13,6 +13,7 @@ execute_process(
)
)
add_library
(
jit_library STATIC
add_library
(
jit_library STATIC
src/device_batched_gemm_softmax_gemm.cpp
src/device_gemm_multiple_d.cpp
src/device_gemm_multiple_d.cpp
src/common.cpp
src/common.cpp
)
)
...
...
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp
View file @
d9676215
...
@@ -33,7 +33,16 @@ struct Problem
...
@@ -33,7 +33,16 @@ struct Problem
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
B1ElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
B1ElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
float
scale
=
1.0
;
std
::
string
AccElementOp
=
"ck::tensor_operation::element_wise::Scale"
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
private:
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
static
const
std
::
size_t
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx
=
0
;
static
const
std
::
size_t
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx
=
0
;
static
const
std
::
size_t
ALayout_idx
=
1
;
static
const
std
::
size_t
ALayout_idx
=
1
;
...
@@ -93,15 +102,6 @@ struct Problem
...
@@ -93,15 +102,6 @@ struct Problem
static
const
std
::
size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx
=
55
;
static
const
std
::
size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx
=
55
;
static
const
std
::
size_t
CBlockTransferScalarPerVector_NWaveNPerXdl_idx
=
56
;
static
const
std
::
size_t
CBlockTransferScalarPerVector_NWaveNPerXdl_idx
=
56
;
static
const
std
::
size_t
MaskOutUpperTriangle_idx
=
57
;
static
const
std
::
size_t
MaskOutUpperTriangle_idx
=
57
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
private:
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
};
};
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace device_batched_gemm_softmax_gemm
...
...
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
View file @
d9676215
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/common.hpp"
#include "ck/host/common.hpp"
#include "
gemm_add_add_fastgelu
_instances.hpp"
#include "
batched_gemm_softmax_gemm
_instances.hpp"
#include <algorithm>
#include <algorithm>
#include <unordered_set>
#include <unordered_set>
...
@@ -57,11 +57,6 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
...
@@ -57,11 +57,6 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
return
instances
;
return
instances
;
}
}
std
::
string
GetElementwiseScaleString
(
const
float
s
)
{
return
"ck::tensor_operation::element_wise::Scale{"
+
std
::
to_string
(
s
)
+
"}"
;
}
Solution
Problem
::
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
Solution
Problem
::
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
{
{
auto
template_str
=
GetInstances
(
arch
).
at
(
idx
);
auto
template_str
=
GetInstances
(
arch
).
at
(
idx
);
...
@@ -73,19 +68,17 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
...
@@ -73,19 +68,17 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
params
[
B0ElementwiseOperation_idx
]
=
BElementOp
;
params
[
B0ElementwiseOperation_idx
]
=
BElementOp
;
params
[
B1ElementwiseOperation_idx
]
=
BElementOp
;
params
[
B1ElementwiseOperation_idx
]
=
BElementOp
;
params
[
CElementwiseOperation_idx
]
=
CElementOp
;
params
[
CElementwiseOperation_idx
]
=
CElementOp
;
params
[
Acc0ElementwiseOperation_idx
]
=
Get
Element
wiseScaleString
(
scale
)
;
params
[
Acc0ElementwiseOperation_idx
]
=
Acc
Element
Op
;
auto
block_size_str
=
params
[
BlockSize_idx
];
auto
block_size_str
=
params
[
BlockSize_idx
];
auto
m_per_block_str
=
params
[
Gemm01MPerBlock_idx
];
auto
m_per_block_str
=
params
[
Gemm01MPerBlock_idx
];
auto
n_per_block_str
=
params
[
Gemm0NPerBlock_idx
];
auto
n_per_block_str
=
params
[
Gemm0NPerBlock_idx
];
auto
k_per_block_str
=
params
[
Gemm0KPerBlock_idx
];
auto
k_per_block_str
=
params
[
Gemm0KPerBlock_idx
];
auto
n1_per_block_str
=
params
[
Gemm1NPerBlock_idx
];
auto
n1_per_block_str
=
params
[
Gemm1NPerBlock_idx
];
auto
k1_per_block_str
=
params
[
Gemm1KPerBlock_idx
];
const
std
::
size_t
block_size
=
std
::
stoi
(
block_size_str
);
const
std
::
size_t
block_size
=
std
::
stoi
(
block_size_str
);
const
std
::
size_t
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
std
::
size_t
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
std
::
size_t
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
std
::
size_t
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
n1_per_block
=
std
::
stoi
(
n1_per_block_str
);
const
std
::
size_t
n1_per_block
=
std
::
stoi
(
n1_per_block_str
);
const
std
::
size_t
k1_per_block
=
std
::
stoi
(
k1_per_block_str
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
O
,
m_per_block
,
n1_per_block
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
O
,
m_per_block
,
n1_per_block
);
params
[
GEMMSpecialization_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
O
,
m_per_block
,
n_per_block
,
k_per_block
,
n1_per_block
);
params
[
GEMMSpecialization_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
O
,
m_per_block
,
n_per_block
,
k_per_block
,
n1_per_block
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment