gaoqiong / composable_kernel_ROCM · Commits

Commit 77190058, authored Jun 07, 2023 by Alan Turner
Formatting
parent 421734ae

Showing 6 changed files with 366 additions and 330 deletions:

    include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp           +14   -10
    library/src/jit_library/include/ck/host/common.hpp                     +3    -2
    library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp    +25   -26
    library/src/jit_library/src/common.cpp                                 +7    -6
    library/src/jit_library/src/device_gemm_multiple_d.cpp                +39   -37
    test/jit_library/jit_library.cpp                                      +278  -249
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -120,22 +120,24 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const BlockToCTileMap_M00_N0_M01Adapt&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        BlockToCTileMap_M00_N0_M01Adapt&&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M,
                                                                  index_t N,
                                                                  index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
        : BlockToCTileMap_M00_N0_M01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)

@@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>

    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0),
                                 c_grid_desc_m_n.GetLength(I1));
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ constexpr bool
    CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }
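Read together, the two hunks give BlockToCTileMap_M00_N0_M01Adapt a full set of defaulted special members plus descriptor-driven construction and grid sizing. A minimal host-only sketch of how such a map is typically driven; MockCGridDescMN, ceil_div, and the plain 0/1 indices are stand-ins invented here (ck's real descriptors use Number<> index tags and carry much more):

    #include <iostream>

    // Invented stand-in for a C-grid descriptor: just two extents, queried the
    // way the constructor and CalculateGridSize above query GetLength(I0/I1).
    struct MockCGridDescMN
    {
        int m, n;
        constexpr int GetLength(int i) const { return i == 0 ? m : n; }
    };

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // Host-side launch pattern: size the grid as
    // ceil(M / MPerBlock) * ceil(N / NPerBlock) workgroups.
    template <int MPerBlock, int NPerBlock>
    int calculate_grid_size(const MockCGridDescMN& desc)
    {
        return ceil_div(desc.GetLength(0), MPerBlock) *
               ceil_div(desc.GetLength(1), NPerBlock);
    }

    int main()
    {
        MockCGridDescMN desc{256, 256};
        std::cout << calculate_grid_size<256, 128>(desc) << '\n'; // prints 2
    }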
library/src/jit_library/include/ck/host/common.hpp

@@ -17,7 +17,8 @@ struct Solution

    std::size_t grid_size;
};

enum class DataType
{
    Half,
    Float,
    Int8,

@@ -26,7 +27,7 @@ enum class DataType {

std::string ToString(DataType dt);

std::unordered_map<std::string, std::pair<const char*, const char*>> GetHeaders();

std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
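integer_divide_ceil's contract is the usual rounding-up division used to count tiles. A self-contained version; the formula is the standard one, only the function name comes from this header:

    #include <cassert>
    #include <cstddef>

    // Rounding-up integer division: smallest q with q * y >= x (for y > 0).
    std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
    {
        return (x + y - 1) / y;
    }

    int main()
    {
        assert(integer_divide_ceil(256, 128) == 2);
        assert(integer_divide_ceil(255, 128) == 2); // a partial tile still needs a block
        assert(integer_divide_ceil(256, 256) == 1);
    }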
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp

@@ -11,7 +11,6 @@

#include <numeric>

#include "ck/host/common.hpp"

namespace ck {
namespace host {
namespace device_gemm_multiple_d {

@@ -49,7 +48,7 @@ struct Problem

    std::vector<Solution> GetSolutions(const std::string& arch) const;

    private:
    std::vector<std::string> GetInstances(const std::string& arch) const;
    Solution MakeSolution(std::size_t idx, const std::string& arch) const;
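Putting this header together with the tests later in the commit, Problem exposes GetSolutions publicly and builds each Solution from a private instance list. A toy stand-in showing that pipeline; the field list, instance strings, and numeric values here are invented, only the GetSolutions/GetInstances/MakeSolution split mirrors the header:

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Solution
    {
        std::string template_str;
        std::size_t block_size;
        std::size_t grid_size;
    };

    struct Problem
    {
        std::size_t M, N, K;

        std::vector<Solution> GetSolutions(const std::string& arch) const
        {
            std::vector<Solution> solutions;
            for(std::size_t i = 0; i < GetInstances(arch).size(); ++i)
                solutions.push_back(MakeSolution(i, arch));
            return solutions;
        }

    private:
        // Real code returns one template-parameter string per tuned kernel instance.
        std::vector<std::string> GetInstances(const std::string&) const
        {
            return {"DeviceGemmToy< 256, 128, 32>", "DeviceGemmToy< 128, 64, 32>"};
        }

        Solution MakeSolution(std::size_t idx, const std::string& arch) const
        {
            return Solution{GetInstances(arch)[idx], 256, (M / 256) * (N / 128)};
        }
    };

    int main()
    {
        Problem problem{256, 256, 256};
        for(const auto& s : problem.GetSolutions("gfx90a"))
            std::cout << s.template_str << "  grid=" << s.grid_size << '\n';
    }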
library/src/jit_library/src/common.cpp

@@ -8,7 +8,8 @@ namespace host {

std::string ToString(DataType dt)
{
    switch(dt)
    {
    case DataType::Float: return "float";
    case DataType::Half: return "ck::half_t";
    case DataType::Int8: return "int8_t";

@@ -17,7 +18,7 @@ std::string ToString(DataType dt)

    throw std::runtime_error("Incorrect data type");
}

std::unordered_map<std::string, std::pair<const char*, const char*>> GetHeaders()
{
    return ck_headers();
}
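A toy analogue of consuming GetHeaders(): a map from header path to a pair of char pointers delimiting an in-memory copy of the header text. That (begin, end) reading of the char* pair is an assumption here, not something this diff states:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>

    std::unordered_map<std::string, std::pair<const char*, const char*>> toy_headers()
    {
        // One embedded header blob; real code would carry the whole ck include tree.
        static const char hdr[] = "#pragma once\n";
        return {{"ck/host/common.hpp", {hdr, hdr + sizeof(hdr) - 1}}};
    }

    int main()
    {
        for(const auto& [name, span] : toy_headers())
            std::cout << name << ":\n" << std::string(span.first, span.second);
    }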
library/src/jit_library/src/device_gemm_multiple_d.cpp

@@ -33,8 +33,7 @@ std::size_t GetGridSize(const std::size_t m,

                        const std::size_t m_per_block,
                        const std::size_t n_per_block)
{
    return integer_divide_ceil(m, m_per_block) * integer_divide_ceil(n, n_per_block);
}

const std::unordered_set<std::string>& get_xdlop_archs()
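Worked example: with the test_Problem configuration below (M = N = 256, and MPerBlock = 256, NPerBlock = 128 from the selected instance's template string), GetGridSize returns integer_divide_ceil(256, 256) * integer_divide_ceil(256, 128) = 1 * 2 = 2, which is exactly the grid_size == 2 assertion in test/jit_library/jit_library.cpp.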
@@ -47,7 +46,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const

{
    std::vector<std::string> instances;
    const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
    if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
    {
        ck::host::instance::gemm_add_add_fastgelu_instances all_instances{};
        if(TransA and TransB)
@@ -68,9 +67,10 @@ std::string MakeLayoutTuple(const std::vector<bool>& layouts)

    auto it = layouts.begin();
    while(it != layouts.end())
    {
        layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor"
                            : "ck::tensor_layout::gemm::RowMajor";
        it = std::next(it);
        if(it != layouts.end())
            layout_tuple += ", ";
    }
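The loop is a separator-between-elements join: append an element, advance, and add ", " only if another element follows. A standalone re-creation; the "ck::Tuple<" prefix/suffix framing is assumed from the surrounding (unshown) lines, and the sample input mirrors the RowColRow test case later in this commit:

    #include <iostream>
    #include <iterator>
    #include <string>
    #include <vector>

    // false = RowMajor, true = ColumnMajor, matching the loop above.
    std::string make_layout_tuple(const std::vector<bool>& layouts)
    {
        std::string layout_tuple = "ck::Tuple<";
        auto it = layouts.begin();
        while(it != layouts.end())
        {
            layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor"
                                : "ck::tensor_layout::gemm::RowMajor";
            it = std::next(it);
            if(it != layouts.end())
                layout_tuple += ", ";
        }
        return layout_tuple + ">";
    }

    int main()
    {
        std::cout << make_layout_tuple({false, true, false}) << '\n';
        // ck::Tuple<...RowMajor, ...ColumnMajor, ...RowMajor>
    }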
@@ -85,7 +85,7 @@ std::string MakeTypeTuple(const std::vector<DataType>& types)

    {
        type_tuple += ToString(*it);
        it = std::next(it);
        if(it != types.end())
            type_tuple += ", ";
    }
    return type_tuple + ">";
@@ -98,14 +98,16 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const

    std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                    std::istream_iterator<std::string>());
    if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
    {
        // Change CBlockTransfer ScalarPerVector if Ds contains other types
        if(std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) {
               return t == DataType::Half;
           }))
        {
            params[params.size() - 3] = "8";
        }
        if(std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) {
               return t == DataType::Float;
           }))
        {
            params[params.size() - 3] = "4";
        }
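The istream_iterator trick above, in isolation: split a whitespace-separated instance string into tokens, patch a slot counted from the end, and re-join later. The sample string is invented for illustration:

    #include <iostream>
    #include <iterator>
    #include <sstream>
    #include <string>
    #include <vector>

    int main()
    {
        std::istringstream iss("DeviceGemm 256 128 32 16 8 4");
        std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                        std::istream_iterator<std::string>());
        params[params.size() - 3] = "8"; // e.g. bump a ScalarPerVector-style slot
        for(const auto& p : params)
            std::cout << p << ' ';
        std::cout << '\n'; // DeviceGemm 256 128 32 8 8 4
    }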
@@ -128,10 +130,11 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const

    const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block);
    params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
    std::string str = std::accumulate(params.begin() + 1,
                                      params.end(),
                                      std::string{},
                                      [](const std::string& a, const std::string& b) {
                                          return a.empty() ? b : a + ", " + b;
                                      });
    str = params.front() + "< " + str + ">";
    return Solution{str, block_size, grid_size};
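The std::accumulate fold above, reduced to its essentials: the first token is the template name, the rest become a comma-separated argument list, and the empty-seed check avoids a leading separator. The sample tokens are invented:

    #include <iostream>
    #include <numeric>
    #include <string>
    #include <vector>

    int main()
    {
        std::vector<std::string> params{"DeviceGemm", "Row", "Col", "256", "128"};
        std::string str = std::accumulate(params.begin() + 1,
                                          params.end(),
                                          std::string{},
                                          [](const std::string& a, const std::string& b) {
                                              return a.empty() ? b : a + ", " + b;
                                          });
        std::cout << params.front() + "< " + str + ">" << '\n';
        // Prints: DeviceGemm< Row, Col, 256, 128>
    }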
@@ -146,7 +149,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const

{
    std::vector<Solution> solutions;
    const std::size_t num_instances = GetInstances(arch).size();
    for(std::size_t i = 0; i < num_instances; ++i)
    {
        solutions.push_back(MakeSolution(i, arch));
    }
@@ -154,7 +157,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const

    return solutions;
}

} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck

test/jit_library/jit_library.cpp
@@ -3,7 +3,8 @@

bool test_Problem()
{
    auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                             256,
                                                             256,
                                                             false,
@@ -27,9 +28,20 @@ bool test_Problem()

    bool pass = true;
    pass &= include_header ==
            "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
    pass &= solutions.size() == 42;
    pass &= template_str ==
            "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< "
            "ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, "
            "ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, "
            "ck::half_t, ck::tensor_operation::element_wise::Passthrough, "
            "ck::tensor_operation::element_wise::Passthrough, "
            "ck::tensor_operation::element_wise::Passthrough, "
            "ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, "
            "8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, "
            "8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, "
            "1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>";
    pass &= grid_size == 2;
    pass &= block_size == 256;
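The only change to the long expectation string appears to be splitting it across adjacent string literals, which the compiler concatenates back into one token (translation phase 6), so the comparison is unchanged. A tiny demonstration:

    #include <cassert>
    #include <string>

    int main()
    {
        // Adjacent literals concatenate before compilation proper, so the two
        // spellings compare equal; the split is purely a formatting choice.
        std::string joined = "ck::half_t, "
                             "float";
        assert(joined == "ck::half_t, float");
    }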
@@ -40,8 +52,9 @@ bool test_GetGemmSpec()

{
    bool pass = true;
    {
        // PadMNK
        auto problem = ck::host::device_gemm_multiple_d::Problem{255,
                                                                 255,
                                                                 255,
                                                                 false,
@@ -62,8 +75,9 @@ bool test_GetGemmSpec()
pass
&=
template_str
.
find
(
"GemmSpecialization::MNKPadding"
)
!=
std
::
string
::
npos
;
}
{
//Default
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
// Default
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
...
...
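The two cases pin down GetGemmSpec's contract at its extremes: 255 is divisible by none of the tile sizes, so every dimension needs padding, while 256 is divisible by all of them. A hedged reconstruction of that divisibility rule; GetGemmSpec's real body is not in this diff, and the real function presumably also emits the mixed specializations (MPadding, MNPadding, ...) that these two assertions do not exercise:

    #include <cassert>
    #include <cstddef>
    #include <string>

    // Guess at the rule, constrained only by the PadMNK (255) and Default (256)
    // test cases above.
    std::string get_gemm_spec_sketch(std::size_t M, std::size_t N, std::size_t K,
                                     std::size_t m_per_block, std::size_t n_per_block,
                                     std::size_t k_per_block)
    {
        const bool pad_m = M % m_per_block != 0;
        const bool pad_n = N % n_per_block != 0;
        const bool pad_k = K % k_per_block != 0;
        if(not pad_m and not pad_n and not pad_k)
            return "ck::tensor_operation::device::GemmSpecialization::Default";
        return "ck::tensor_operation::device::GemmSpecialization::MNKPadding";
    }

    int main()
    {
        assert(get_gemm_spec_sketch(255, 255, 255, 256, 128, 32) ==
               "ck::tensor_operation::device::GemmSpecialization::MNKPadding");
        assert(get_gemm_spec_sketch(256, 256, 256, 256, 128, 32) ==
               "ck::tensor_operation::device::GemmSpecialization::Default");
    }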
@@ -91,8 +105,9 @@ bool test_GetInstances()

{
    bool pass = true;
    {
        // Col Col Fp16
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 true,
@@ -109,8 +124,9 @@ bool test_GetInstances()

        pass &= problem.GetSolutions("gfx90a").size() == 51;
    }
    {
        // Col Row Fp16
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 true,
@@ -127,8 +143,9 @@ bool test_GetInstances()

        pass &= problem.GetSolutions("gfx90a").size() == 51;
    }
    {
        // Row Col Fp16
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 false,
@@ -145,8 +162,9 @@ bool test_GetInstances()

        pass &= problem.GetSolutions("gfx90a").size() == 42;
    }
    {
        // Row Row Int8
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 false,
@@ -163,8 +181,9 @@ bool test_GetInstances()

        pass &= problem.GetSolutions("gfx90a").size() == 48;
    }
    {
        // Col Col Int8
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 true,
@@ -181,8 +200,9 @@ bool test_GetInstances()

        pass &= problem.GetSolutions("gfx90a").size() == 48;
    }
    {
        // Col Row Int8
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 true,
@@ -199,8 +219,9 @@ bool test_GetInstances()

        pass &= problem.GetSolutions("gfx90a").size() == 48;
    }
    {
        // Row Col Int8
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 false,
@@ -217,8 +238,9 @@ bool test_GetInstances()

        pass &= problem.GetSolutions("gfx90a").size() == 39;
    }
    {
        // Row Row Int8
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 false,
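For quick reference, the gfx90a solution counts asserted across these hunks, keyed by each block's comment (the Fp16 Row/Row case and the final Int8 block's count fall outside the hunks shown in this diff; Row/Col Fp16 = 42 also matches the solutions.size() == 42 check in test_Problem):

    A/B layouts   Fp16   Int8
    Col, Col       51     48
    Col, Row       51     48
    Row, Col       42     39
    Row, Row       --     48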
@@ -243,7 +265,8 @@ bool test_MakeLayoutsTuple()

    bool pass = true;
    {
        // Empty Tuple
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 false,
@@ -264,7 +287,8 @@ bool test_MakeLayoutsTuple()
}
{
// RowColRow Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
...
...
@@ -281,7 +305,10 @@ bool test_MakeLayoutsTuple()

        const auto solutions = problem.GetSolutions("gfx90a");
        const auto& solution = solutions.at(0);
        const auto template_str = solution.template_str;
        pass &= template_str.find("ck::Tuple<ck::tensor_layout::gemm::RowMajor, "
                                  "ck::tensor_layout::gemm::ColumnMajor, "
                                  "ck::tensor_layout::gemm::RowMajor>") != std::string::npos;
    }
    return pass;
@@ -292,7 +319,8 @@ bool test_MakeTypeTuple()

    bool pass = true;
    {
        // Empty Tuple
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 false,
@@ -313,7 +341,8 @@ bool test_MakeTypeTuple()

    }
    {
        // Half Int8 Tuple
        auto problem = ck::host::device_gemm_multiple_d::Problem{256,
                                                                 256,
                                                                 256,
                                                                 false,