Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ba251e4a
Commit
ba251e4a
authored
Sep 29, 2023
by
Umang Yadav
Browse files
Formatting and put find_package(hip) behind JIT_LIB flag
parent
000c8bcf
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
114 additions
and
108 deletions
+114
-108
CMakeLists.txt
CMakeLists.txt
+10
-9
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp
...rary/include/ck/host/device_batched_gemm_softmax_gemm.hpp
+75
-74
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
.../src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
+17
-16
library/src/jit_library/src/device_gemm_multiple_d.cpp
library/src/jit_library/src/device_gemm_multiple_d.cpp
+12
-9
No files found.
CMakeLists.txt
View file @
ba251e4a
...
...
@@ -108,15 +108,6 @@ if(GPU_TARGETS)
else
()
message
(
"Building CK for the following targets:
${
AMDGPU_TARGETS
}
"
)
endif
()
find_package
(
hip
)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math
(
EXPR hip_VERSION_FLAT
"(
${
hip_VERSION_MAJOR
}
* 1000 +
${
hip_VERSION_MINOR
}
) * 100000 +
${
hip_VERSION_PATCH
}
"
)
message
(
"hip_version_flat=
${
hip_VERSION_FLAT
}
"
)
if
(
${
hip_VERSION_FLAT
}
GREATER 500723302
)
message
(
"Adding the fno-offload-uniform-block compiler flag"
)
add_compile_options
(
-fno-offload-uniform-block
)
endif
()
option
(
USE_BITINT_EXTENSION_INT4,
"Whether to enable clang's BitInt extension to provide int4 data type."
OFF
)
option
(
USE_OPT_NAVI3X,
"Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons."
OFF
)
...
...
@@ -147,6 +138,16 @@ message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
option
(
CK_BUILD_JIT_LIB,
"Only build the CK JIT Helper Library"
OFF
)
if
(
NOT CK_BUILD_JIT_LIB
)
find_package
(
hip
)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math
(
EXPR hip_VERSION_FLAT
"(
${
hip_VERSION_MAJOR
}
* 1000 +
${
hip_VERSION_MINOR
}
) * 100000 +
${
hip_VERSION_PATCH
}
"
)
message
(
"hip_version_flat=
${
hip_VERSION_FLAT
}
"
)
if
(
${
hip_VERSION_FLAT
}
GREATER 500723302
)
message
(
"Adding the fno-offload-uniform-block compiler flag"
)
add_compile_options
(
-fno-offload-uniform-block
)
endif
()
option
(
USE_BITINT_EXTENSION_INT4,
"Whether to enable clang's BitInt extension to provide int4 data type."
OFF
)
option
(
USE_OPT_NAVI3X,
"Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons."
OFF
)
...
...
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp
View file @
ba251e4a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -17,23 +17,23 @@ namespace device_batched_gemm_softmax_gemm {
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
B1ElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
AccElementOp
=
"ck::tensor_operation::element_wise::Scale"
;
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
B1ElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
AccElementOp
=
"ck::tensor_operation::element_wise::Scale"
;
std
::
string
GetIncludeHeader
()
const
;
...
...
@@ -44,64 +44,65 @@ struct Problem
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
static
const
std
::
size_t
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx
=
0
;
static
const
std
::
size_t
ALayout_idx
=
1
;
static
const
std
::
size_t
B0Layout_idx
=
2
;
static
const
std
::
size_t
B1Layout_idx
=
3
;
static
const
std
::
size_t
CLayout_idx
=
4
;
static
const
std
::
size_t
ADataType_idx
=
5
;
static
const
std
::
size_t
B0DataType_idx
=
6
;
static
const
std
::
size_t
B1DataType_idx
=
7
;
static
const
std
::
size_t
CDataType_idx
=
8
;
static
const
std
::
size_t
AccDataType_idx
=
9
;
static
const
std
::
size_t
CShuffleDataType_idx
=
10
;
static
const
std
::
size_t
AElementwiseOperation_idx
=
11
;
static
const
std
::
size_t
B0ElementwiseOperation_idx
=
12
;
static
const
std
::
size_t
Acc0ElementwiseOperation_idx
=
13
;
static
const
std
::
size_t
B1ElementwiseOperation_idx
=
14
;
static
const
std
::
size_t
CElementwiseOperation_idx
=
15
;
static
const
std
::
size_t
GEMMSpecialization_idx
=
16
;
static
const
std
::
size_t
NumGemmKPrefetchStage_idx
=
17
;
static
const
std
::
size_t
BlockSize_idx
=
18
;
static
const
std
::
size_t
Gemm01MPerBlock_idx
=
19
;
static
const
std
::
size_t
Gemm0NPerBlock_idx
=
20
;
static
const
std
::
size_t
Gemm0KPerBlock_idx
=
21
;
static
const
std
::
size_t
Gemm1NPerBlock_idx
=
22
;
static
const
std
::
size_t
Gemm1KPerBlock_idx
=
23
;
static
const
std
::
size_t
AK1_idx
=
24
;
static
const
std
::
size_t
BK1_idx
=
25
;
static
const
std
::
size_t
B1K1_idx
=
26
;
static
const
std
::
size_t
MPerXDL_idx
=
27
;
static
const
std
::
size_t
NPerXDL_idx
=
28
;
static
const
std
::
size_t
Gemm0MXdlPerWave_idx
=
29
;
static
const
std
::
size_t
Gemm0NXdlPerWave_idx
=
30
;
static
const
std
::
size_t
Gemm1NXdlPerWave_idx
=
31
;
static
const
std
::
size_t
ABlockTransferThreadClusterLengths_K0_M_K1_idx
=
32
;
static
const
std
::
size_t
ABlockTransferThreadClusterArrangeOrder_idx
=
33
;
static
const
std
::
size_t
ABlockTransferSrcAccessOrder_idx
=
34
;
static
const
std
::
size_t
ABlockTransferSrcVectorDim_idx
=
35
;
static
const
std
::
size_t
ABlockTransferSrcScalarPerVector_idx
=
36
;
static
const
std
::
size_t
ABlockTransferDstScalarPerVector_K1_idx
=
37
;
static
const
std
::
size_t
ABlockLdsAddExtraM_idx
=
38
;
static
const
std
::
size_t
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx
=
0
;
static
const
std
::
size_t
ALayout_idx
=
1
;
static
const
std
::
size_t
B0Layout_idx
=
2
;
static
const
std
::
size_t
B1Layout_idx
=
3
;
static
const
std
::
size_t
CLayout_idx
=
4
;
static
const
std
::
size_t
ADataType_idx
=
5
;
static
const
std
::
size_t
B0DataType_idx
=
6
;
static
const
std
::
size_t
B1DataType_idx
=
7
;
static
const
std
::
size_t
CDataType_idx
=
8
;
static
const
std
::
size_t
AccDataType_idx
=
9
;
static
const
std
::
size_t
CShuffleDataType_idx
=
10
;
static
const
std
::
size_t
AElementwiseOperation_idx
=
11
;
static
const
std
::
size_t
B0ElementwiseOperation_idx
=
12
;
static
const
std
::
size_t
Acc0ElementwiseOperation_idx
=
13
;
static
const
std
::
size_t
B1ElementwiseOperation_idx
=
14
;
static
const
std
::
size_t
CElementwiseOperation_idx
=
15
;
static
const
std
::
size_t
GEMMSpecialization_idx
=
16
;
static
const
std
::
size_t
NumGemmKPrefetchStage_idx
=
17
;
static
const
std
::
size_t
BlockSize_idx
=
18
;
static
const
std
::
size_t
Gemm01MPerBlock_idx
=
19
;
static
const
std
::
size_t
Gemm0NPerBlock_idx
=
20
;
static
const
std
::
size_t
Gemm0KPerBlock_idx
=
21
;
static
const
std
::
size_t
Gemm1NPerBlock_idx
=
22
;
static
const
std
::
size_t
Gemm1KPerBlock_idx
=
23
;
static
const
std
::
size_t
AK1_idx
=
24
;
static
const
std
::
size_t
BK1_idx
=
25
;
static
const
std
::
size_t
B1K1_idx
=
26
;
static
const
std
::
size_t
MPerXDL_idx
=
27
;
static
const
std
::
size_t
NPerXDL_idx
=
28
;
static
const
std
::
size_t
Gemm0MXdlPerWave_idx
=
29
;
static
const
std
::
size_t
Gemm0NXdlPerWave_idx
=
30
;
static
const
std
::
size_t
Gemm1NXdlPerWave_idx
=
31
;
static
const
std
::
size_t
ABlockTransferThreadClusterLengths_K0_M_K1_idx
=
32
;
static
const
std
::
size_t
ABlockTransferThreadClusterArrangeOrder_idx
=
33
;
static
const
std
::
size_t
ABlockTransferSrcAccessOrder_idx
=
34
;
static
const
std
::
size_t
ABlockTransferSrcVectorDim_idx
=
35
;
static
const
std
::
size_t
ABlockTransferSrcScalarPerVector_idx
=
36
;
static
const
std
::
size_t
ABlockTransferDstScalarPerVector_K1_idx
=
37
;
static
const
std
::
size_t
ABlockLdsAddExtraM_idx
=
38
;
static
const
std
::
size_t
B0BlockTransferThreadClusterLengths_K0_N_K1_idx
=
39
;
static
const
std
::
size_t
B0BlockTransferThreadClusterArrangeOrder_idx
=
40
;
static
const
std
::
size_t
B0BlockTransferSrcAccessOrder_idx
=
41
;
static
const
std
::
size_t
B0BlockTransferSrcVectorDim_idx
=
42
;
static
const
std
::
size_t
B0BlockTransferSrcScalarPerVector_idx
=
43
;
static
const
std
::
size_t
B0BlockTransferDstScalarPerVector_K1_idx
=
44
;
static
const
std
::
size_t
B0BlockLdsAddExtraN_idx
=
45
;
static
const
std
::
size_t
B0BlockTransferThreadClusterArrangeOrder_idx
=
40
;
static
const
std
::
size_t
B0BlockTransferSrcAccessOrder_idx
=
41
;
static
const
std
::
size_t
B0BlockTransferSrcVectorDim_idx
=
42
;
static
const
std
::
size_t
B0BlockTransferSrcScalarPerVector_idx
=
43
;
static
const
std
::
size_t
B0BlockTransferDstScalarPerVector_K1_idx
=
44
;
static
const
std
::
size_t
B0BlockLdsAddExtraN_idx
=
45
;
static
const
std
::
size_t
B1BlockTransferThreadClusterLengths_K0_N_K1_idx
=
46
;
static
const
std
::
size_t
B1BlockTransferThreadClusterArrangeOrder_idx
=
47
;
static
const
std
::
size_t
B1BlockTransferSrcAccessOrder_idx
=
48
;
static
const
std
::
size_t
B1BlockTransferSrcVectorDim_idx
=
49
;
static
const
std
::
size_t
B1BlockTransferSrcScalarPerVector_idx
=
50
;
static
const
std
::
size_t
B1BlockTransferDstScalarPerVector_K1_idx
=
51
;
static
const
std
::
size_t
B1BlockLdsAddExtraN_idx
=
52
;
static
const
std
::
size_t
CShuffleMXdlPerWavePerShuffle_idx
=
53
;
static
const
std
::
size_t
CShuffleNXdlPerWavePerShuffle_idx
=
54
;
static
const
std
::
size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx
=
55
;
static
const
std
::
size_t
CBlockTransferScalarPerVector_NWaveNPerXdl_idx
=
56
;
static
const
std
::
size_t
MaskOutUpperTriangle_idx
=
57
;
static
const
std
::
size_t
B1BlockTransferThreadClusterArrangeOrder_idx
=
47
;
static
const
std
::
size_t
B1BlockTransferSrcAccessOrder_idx
=
48
;
static
const
std
::
size_t
B1BlockTransferSrcVectorDim_idx
=
49
;
static
const
std
::
size_t
B1BlockTransferSrcScalarPerVector_idx
=
50
;
static
const
std
::
size_t
B1BlockTransferDstScalarPerVector_K1_idx
=
51
;
static
const
std
::
size_t
B1BlockLdsAddExtraN_idx
=
52
;
static
const
std
::
size_t
CShuffleMXdlPerWavePerShuffle_idx
=
53
;
static
const
std
::
size_t
CShuffleNXdlPerWavePerShuffle_idx
=
54
;
static
const
std
::
size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx
=
55
;
static
const
std
::
size_t
CBlockTransferScalarPerVector_NWaveNPerXdl_idx
=
56
;
static
const
std
::
size_t
MaskOutUpperTriangle_idx
=
57
;
};
}
// namespace device_batched_gemm_softmax_gemm
...
...
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp
View file @
ba251e4a
...
...
@@ -64,23 +64,24 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
istream_iterator
<
std
::
string
>
());
params
[
AElementwiseOperation_idx
]
=
AElementOp
;
params
[
B0ElementwiseOperation_idx
]
=
BElementOp
;
params
[
B1ElementwiseOperation_idx
]
=
BElementOp
;
params
[
CElementwiseOperation_idx
]
=
CElementOp
;
params
[
AElementwiseOperation_idx
]
=
AElementOp
;
params
[
B0ElementwiseOperation_idx
]
=
BElementOp
;
params
[
B1ElementwiseOperation_idx
]
=
BElementOp
;
params
[
CElementwiseOperation_idx
]
=
CElementOp
;
params
[
Acc0ElementwiseOperation_idx
]
=
AccElementOp
;
auto
block_size_str
=
params
[
BlockSize_idx
];
auto
m_per_block_str
=
params
[
Gemm01MPerBlock_idx
];
auto
n_per_block_str
=
params
[
Gemm0NPerBlock_idx
];
auto
k_per_block_str
=
params
[
Gemm0KPerBlock_idx
];
auto
n1_per_block_str
=
params
[
Gemm1NPerBlock_idx
];
const
std
::
size_t
block_size
=
std
::
stoi
(
block_size_str
);
const
std
::
size_t
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
std
::
size_t
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
n1_per_block
=
std
::
stoi
(
n1_per_block_str
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
O
,
m_per_block
,
n1_per_block
);
params
[
GEMMSpecialization_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
O
,
m_per_block
,
n_per_block
,
k_per_block
,
n1_per_block
);
auto
block_size_str
=
params
[
BlockSize_idx
];
auto
m_per_block_str
=
params
[
Gemm01MPerBlock_idx
];
auto
n_per_block_str
=
params
[
Gemm0NPerBlock_idx
];
auto
k_per_block_str
=
params
[
Gemm0KPerBlock_idx
];
auto
n1_per_block_str
=
params
[
Gemm1NPerBlock_idx
];
const
std
::
size_t
block_size
=
std
::
stoi
(
block_size_str
);
const
std
::
size_t
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
std
::
size_t
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
n1_per_block
=
std
::
stoi
(
n1_per_block_str
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
O
,
m_per_block
,
n1_per_block
);
params
[
GEMMSpecialization_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
O
,
m_per_block
,
n_per_block
,
k_per_block
,
n1_per_block
);
std
::
string
str
=
std
::
accumulate
(
params
.
begin
()
+
1
,
...
...
library/src/jit_library/src/device_gemm_multiple_d.cpp
View file @
ba251e4a
...
...
@@ -101,18 +101,21 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
if
(
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
)
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if
(
EDataType
==
DataType
::
Half
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Half
;
}))
if
(
EDataType
==
DataType
::
Half
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Half
;
}))
{
params
[
params
.
size
()
-
3
]
=
"8"
;
}
if
(
EDataType
==
DataType
::
Float
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Float
;
}))
if
(
EDataType
==
DataType
::
Float
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Float
;
}))
{
params
[
params
.
size
()
-
3
]
=
"4"
;
}
if
(
EDataType
==
DataType
::
Int32
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Int32
;
}))
if
(
EDataType
==
DataType
::
Int32
or
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Int32
;
}))
{
params
[
params
.
size
()
-
3
]
=
"4"
;
}
...
...
@@ -134,14 +137,14 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
N
,
m_per_block
,
n_per_block
);
params
[
gemm_spec_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
m_per_block
,
n_per_block
,
k_per_block
);
std
::
string
str
=
std
::
accumulate
(
params
.
begin
()
+
1
,
params
.
end
(),
std
::
string
{},
[](
const
std
::
string
&
a
,
const
std
::
string
&
b
)
{
return
a
.
empty
()
?
b
:
a
+
", "
+
b
;
});
str
=
params
.
front
()
+
"< "
+
str
+
">"
;
if
(
params
.
back
().
find
(
"v2"
)
!=
std
::
string
::
npos
and
K
%
k_per_block
!=
0
)
if
(
params
.
back
().
find
(
"v2"
)
!=
std
::
string
::
npos
and
K
%
k_per_block
!=
0
)
str
=
""
;
return
Solution
{
str
,
block_size
,
grid_size
};
...
...
@@ -159,7 +162,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
for
(
std
::
size_t
i
=
0
;
i
<
num_instances
;
++
i
)
{
auto
solution
=
MakeSolution
(
i
,
arch
);
if
(
solution
.
template_str
!=
""
)
if
(
solution
.
template_str
!=
""
)
solutions
.
push_back
(
solution
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment