Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
adbefd90
Commit
adbefd90
authored
May 11, 2023
by
Astha Rai
Browse files
modularized input for examples, compiles into .so
parent
f945f40a
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1169 additions
and
96 deletions
+1169
-96
python/ait_impl/generation/ex/normal/Makefile
python/ait_impl/generation/ex/normal/Makefile
+16
-14
python/ait_impl/generation/ex/normal/ex.cpp
python/ait_impl/generation/ex/normal/ex.cpp
+183
-0
python/ait_impl/generation/ex/normal/gemm
python/ait_impl/generation/ex/normal/gemm
+0
-0
python/ait_impl/generation/ex/normal/gemm_op.py
python/ait_impl/generation/ex/normal/gemm_op.py
+144
-0
python/ait_impl/generation/ex/shared/128.cpp
python/ait_impl/generation/ex/shared/128.cpp
+183
-0
python/ait_impl/generation/ex/shared/256.cpp
python/ait_impl/generation/ex/shared/256.cpp
+183
-0
python/ait_impl/generation/ex/shared/Makefile
python/ait_impl/generation/ex/shared/Makefile
+14
-5
python/ait_impl/generation/ex/shared/__pycache__/ck_types.cpython-38.pyc
.../generation/ex/shared/__pycache__/ck_types.cpython-38.pyc
+0
-0
python/ait_impl/generation/ex/shared/__pycache__/gemm_ex.cpython-38.pyc
...l/generation/ex/shared/__pycache__/gemm_ex.cpython-38.pyc
+0
-0
python/ait_impl/generation/ex/shared/__pycache__/gemm_op.cpython-38.pyc
...l/generation/ex/shared/__pycache__/gemm_op.cpython-38.pyc
+0
-0
python/ait_impl/generation/ex/shared/__pycache__/user.cpython-38.pyc
...impl/generation/ex/shared/__pycache__/user.cpython-38.pyc
+0
-0
python/ait_impl/generation/ex/shared/ck_types.py
python/ait_impl/generation/ex/shared/ck_types.py
+17
-0
python/ait_impl/generation/ex/shared/driver.py
python/ait_impl/generation/ex/shared/driver.py
+24
-0
python/ait_impl/generation/ex/shared/gemm_ex.py
python/ait_impl/generation/ex/shared/gemm_ex.py
+66
-77
python/ait_impl/generation/ex/shared/gemm_op.py
python/ait_impl/generation/ex/shared/gemm_op.py
+144
-0
python/ait_impl/generation/ex/shared/make_template.py
python/ait_impl/generation/ex/shared/make_template.py
+75
-0
python/ait_impl/generation/ex/shared/user.py
python/ait_impl/generation/ex/shared/user.py
+120
-0
No files found.
python/ait_impl/generation/ex/normal/Makefile
View file @
adbefd90
CFLAGS
=
-I
~/workspace/composable_kernel/include
-I
/opt/workspace/rocm-5.1.1/hip/include
-I
~/workspace/composable_kernel/include/
-I
~/workspace/composable_kernel/include/ck/
-I
~/workspace/composable_kernel/example/01_gemm/
-I
~/workspace/composable_kernel/library/include/
-I
~/workspace/composable_kernel/library/src/utility/
-I
~/workspace/composable_kernel/include/ck/problem_transform/
-I
~/workspace/composable_kernel/include/ck/tensor/
-I
~/workspace/composable_kernel/include/ck/tensor_description/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/
-I
~/workspace/composable_kernel/include/ck/host_utility
-I
/external/include/half/
-I
~/workspace/composable_kernel/library/include/ck/library/host/
-I
~/workspace/composable_kernel/library/include/ck/library/host_tensor/
-I
~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/
-I
~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/
-I
~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/
" + "
reduce/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_op/
-I
~/workspace/composable_kernel/library/include/ck/library/utility/
-I
~/workspace/composable_kernel/profiler/include/
CXXFLAGS
=
-std
=
c++17
gemm
:
ex.o host_tensor.o device_memory.o
hipcc
$(CXXFLAGS)
$(CFLAGS)
ex.o host_tensor.o device_memory.o
-o
gemm
device_memory.o
:
../../../../../library/src/utility/device_memory.cpp
hipcc
$(CXXFLAGS)
$(CFLAGS)
-c
../../../../../library/src/utility/device_memory.cpp
host_tensor.o
:
../../../../../library/src/utility/host_tensor.cpp
hipcc
$(CXXFLAGS)
$(CFLAGS)
-c
../../../../../library/src/utility/host_tensor.cpp
ex.o
:
hipcc
-fPIC
-fvisibility
=
hidden
$(CXXFLAGS)
-w
/opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc
$(CFLAGS)
-L
/opt/rocm-5.3.0/rocrand
-lrocrand
-x
hip
-c
ex.cpp
CFLAGS
=
-I
~/workspace/composable_kernel/include
-I
/opt/workspace/rocm-5.1.1/hip/include
-I
~/workspace/composable_kernel/include/
-I
~/workspace/composable_kernel/include/ck/
-I
~/workspace/composable_kernel/example/01_gemm/
-I
~/workspace/composable_kernel/library/include/
-I
~/workspace/composable_kernel/library/src/utility/
-I
~/workspace/composable_kernel/include/ck/problem_transform/
-I
~/workspace/composable_kernel/include/ck/tensor/
-I
~/workspace/composable_kernel/include/ck/tensor_description/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/
-I
~/workspace/composable_kernel/include/ck/host_utility
-I
/external/include/half/
-I
~/workspace/composable_kernel/library/include/ck/library/host/
-I
~/workspace/composable_kernel/library/include/ck/library/host_tensor/
-I
~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/
-I
~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/
-I
~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/
" + "
reduce/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_op/
-I
~/workspace/composable_kernel/library/include/ck/library/utility/
-I
~/workspace/composable_kernel/profiler/include/
CXXFLAGS
=
-std
=
c++17
gemm
:
ex.o host_tensor.o device_memory.o
hipcc
$(CXXFLAGS)
$(CFLAGS)
ex.o host_tensor.o device_memory.o
-o
gemm
device_memory.o
:
../../../../../library/src/utility/device_memory.cpp
hipcc
$(CXXFLAGS)
$(CFLAGS)
-c
../../../../../library/src/utility/device_memory.cpp
host_tensor.o
:
../../../../../library/src/utility/host_tensor.cpp
hipcc
$(CXXFLAGS)
$(CFLAGS)
-c
../../../../../library/src/utility/host_tensor.cpp
ex.o
:
hipcc
-fPIC
-fvisibility
=
hidden
$(CXXFLAGS)
-w
/opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc
$(CFLAGS)
-L
/opt/rocm-5.3.0/rocrand
-lrocrand
-x
hip
-c
ex.cpp
\ No newline at end of file
python/ait_impl/generation/ex/normal/ex.cpp
0 → 100644
View file @
adbefd90
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ALayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmDl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
,
256
,
128
,
128
,
16
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
bool
run_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
{
using
namespace
ck
::
literals
;
auto
&
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
]
=
problem_size
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
Tensor
<
ck
::
half_t
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
{}));
Tensor
<
ck
::
half_t
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
ck
::
tensor_layout
::
gemm
::
RowMajor
{}));
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ck
::
half_t
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ck
::
half_t
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
break
;
default:
ck
::
utils
::
FillUniformDistribution
<
ck
::
half_t
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
ck
::
half_t
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
}
Tensor
<
ck
::
half_t
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
ck
::
half_t
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ck
::
half_t
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ck
::
half_t
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
ck
::
half_t
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
auto
a_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
auto
b_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
auto
c_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ck
::
half_t
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ck
::
half_t
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ck
::
half_t
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
true
;
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
flop
=
2
_uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ck
::
half_t
)
*
M
*
K
+
sizeof
(
ck
::
half_t
)
*
K
*
N
+
sizeof
(
ck
::
half_t
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
config
.
do_verification
)
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
}
return
true
;
}
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
ExecutionConfig
config
;
return
!
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
||
run_gemm
(
problem_size
,
config
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
python/ait_impl/generation/ex/normal/gemm
0 → 100755
View file @
adbefd90
File added
python/ait_impl/generation/ex/normal/gemm_op.py
0 → 100644
View file @
adbefd90
#take in input for gemm from user, send it to example template
import
enum
import
ck_types
from
copy
import
deepcopy
from
dataclasses
import
dataclass
from
enum
import
auto
from
typing
import
List
from
ck_types
import
*
class
GemmType
():
GemmDefault
=
"ck::tensor_operation::device::GemmSpecialization::Default"
# class GemmSpecialization(enum.Enum):
# GemmDefault = auto()
# MNKPadding = auto()
# MNPadding = auto()
# MNOPadding = auto()
# MNKOPadding = auto()
# GemmSpecializationTag = {
# GemmSpecialization.GemmDefault: "ck::tensor_operation::device::GemmSpecialization::Default",
# GemmSpecialization.MNKPadding: "ck::tensor_operation::device::GemmSpecialization::MNKPadding",
# GemmSpecialization.MNPadding: "ck::tensor_operation::device::GemmSpecialization::MNPadding",
# GemmSpecialization.MNOPadding: "ck::tensor_operation::device::GemmSpecialization::MNOPadding",
# GemmSpecialization.MNKOPadding: "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
# }
@
dataclass
class
TileDesc
:
block_size
:
int
m_per_block
:
int
n_per_block
:
int
k_per_block
:
int
k1
:
int
m_per_thread
:
int
n_per_thread
:
int
k_per_thread
:
int
m1n1_thcluster_m1xs
:
str
m1n1_thcluster_n1xs
:
str
def
__str__
(
self
)
->
str
:
values
=
list
(
self
.
__dict__
.
values
())
return
"_"
.
join
([
str
(
x
)
for
x
in
values
])
return
template
.
render
(
param
=
args
)
@
dataclass
class
BlockTransferDesc
:
thread_slice_length
:
str
thread_cluster_length
:
str
thread_cluster_arrange_order
:
str
src_access_order
:
str
src_vec_tensor_lengths
:
str
src_vec_tensor_cont_dim_order
:
str
dst_vec_tensor_lengths
:
str
def
__str__
(
self
)
->
str
:
args
=
deepcopy
(
self
.
__dict__
)
args
[
"thread_cluster_length"
]
=
[
str
(
x
)
for
x
in
self
.
thread_cluster_length
]
args
[
"thread_cluster_arrange_order"
]
=
[
str
(
x
)
for
x
in
self
.
thread_cluster_arrange_order
]
args
[
"src_access_order"
]
=
[
str
(
x
)
for
x
in
self
.
src_access_order
]
@
dataclass
class
CBlockTransferDesc
:
src_dst_access_order
:
str
src_dst_vec_dim
:
int
dst_scalar_per_vector
:
int
def
__str__
(
self
)
->
str
:
args
=
deepcopy
(
self
.
__dict__
)
#args["m_n_block_wave_per_xdl"] = [str(x) for x in self.m_n_block_wave_per_xdl]
@
dataclass
class
GemmOperation
:
A
:
TensorDesc
B
:
TensorDesc
C
:
TensorDesc
a_elem_op
:
TensorOperation
b_elem_op
:
TensorOperation
epilogue_functor
:
TensorOperation
gemm_specialization
:
GemmType
#GemmSpecialization
tile_desc
:
TileDesc
a_block_transfer
:
BlockTransferDesc
b_block_transfer
:
BlockTransferDesc
b1_block_transfer
:
BlockTransferDesc
=
None
c_block_transfer
:
CBlockTransferDesc
=
None
def
__str__
(
self
)
->
str
:
io_name
=
"{gemm_kind}_{gemm_specialization}_{a_dtype}{b_dtype}{c_dtype}_{a_layout}{b_layout}{c_layout}"
.
format
(
#gemm_kind=library.GemmKindNames[self.operation_kind],
gemm_specialization
=
self
.
gemm_specialization
.
value
,
a_dtype
=
[
self
.
A
.
element
],
b_dtype
=
[
self
.
B
.
element
],
c_dtype
=
[
self
.
C
.
element
],
a_layout
=
[
self
.
A
.
layout
],
b_layout
=
[
self
.
B
.
layout
],
c_layout
=
[
self
.
C
.
layout
],
)
extra_tile
=
""
if
self
.
c_block_transfer
is
not
None
:
if
self
.
c_block_transfer
.
scalar_per_vector
==
4
:
extra_tile
=
"_C4"
elif
self
.
c_block_transfer
.
scalar_per_vector
==
1
:
extra_tile
=
"_C1"
tile_name
=
str
(
self
.
tile_desc
)
+
extra_tile
return
"{io_name}_{tile_name}_{epilogue_functor}"
.
format
(
io_name
=
io_name
,
tile_name
=
tile_name
,
epilogue_functor
=
[
self
.
epilogue_functor
],
)
def
accumulator_type
(
self
):
return
DataType
.
f16
#f.32?
if
__name__
==
"__main__"
:
A
=
TensorDesc
(
DataType
.
f16
,
Layout
.
RowMajor
)
B
=
TensorDesc
(
DataType
.
f16
,
Layout
.
ColumnMajor
)
C
=
TensorDesc
(
DataType
.
f16
,
Layout
.
RowMajor
)
GemmOp
=
GemmOperation
(
A
=
A
,
B
=
B
,
C
=
C
,
a_elem_op
=
TensorOperation
.
PassThrough
,
b_elem_op
=
TensorOperation
.
PassThrough
,
epilogue_functor
=
TensorOperation
.
PassThrough
,
gemm_specialization
=
GemmType
.
GemmDefault
,
tile_desc
=
TileDesc
(
256
,
256
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
),
a_block_transfer
=
BlockTransferDesc
(
[
4
,
64
,
1
],
[
1
,
0
,
2
],
[
1
,
0
,
2
],
2
,
8
,
8
,
1
,
True
),
b_block_transfer
=
BlockTransferDesc
(
[
8
,
32
,
1
],
[
0
,
2
,
1
],
[
0
,
2
,
1
],
1
,
4
,
1
,
0
,
True
),
c_block_transfer
=
CBlockTransferDesc
(
1
,
1
,
[
1
,
32
,
1
,
8
],
8
),
#ds_dtype=[DataType.f16],
)
print
(
GemmOp
.
a_elem_op
)
python/ait_impl/generation/ex/shared/128.cpp
0 → 100644
View file @
adbefd90
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ALayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmDl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
,
128
,
128
,
128
,
32
,
2
,
32
,
32
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
1
,
2
,
3
,
5
,
5
,
6
>
,
6
,
5
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
bool
run_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
{
using
namespace
ck
::
literals
;
auto
&
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
]
=
problem_size
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
Tensor
<
ck
::
half_t
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
{}));
Tensor
<
ck
::
half_t
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
ck
::
tensor_layout
::
gemm
::
RowMajor
{}));
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ck
::
half_t
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ck
::
half_t
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
break
;
default:
ck
::
utils
::
FillUniformDistribution
<
ck
::
half_t
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
ck
::
half_t
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
}
Tensor
<
ck
::
half_t
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
ck
::
half_t
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ck
::
half_t
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ck
::
half_t
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
ck
::
half_t
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
auto
a_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
auto
b_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
auto
c_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ck
::
half_t
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ck
::
half_t
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ck
::
half_t
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
true
;
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
flop
=
2
_uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ck
::
half_t
)
*
M
*
K
+
sizeof
(
ck
::
half_t
)
*
K
*
N
+
sizeof
(
ck
::
half_t
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
config
.
do_verification
)
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
}
return
true
;
}
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
ExecutionConfig
config
;
return
!
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
||
run_gemm
(
problem_size
,
config
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
python/ait_impl/generation/ex/shared/256.cpp
0 → 100644
View file @
adbefd90
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ALayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmDl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
,
256
,
128
,
128
,
16
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
2
,
1
,
4
,
2
>
,
S
<
8
,
1
,
32
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
1
>
,
S
<
0
,
3
,
1
,
2
>
,
S
<
1
,
1
,
4
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
bool
run_gemm
(
const
ProblemSize
&
problem_size
,
const
ExecutionConfig
&
config
)
{
using
namespace
ck
::
literals
;
auto
&
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
]
=
problem_size
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
Tensor
<
ck
::
half_t
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
{}));
Tensor
<
ck
::
half_t
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
ck
::
tensor_layout
::
gemm
::
RowMajor
{}));
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ck
::
half_t
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ck
::
half_t
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
break
;
default:
ck
::
utils
::
FillUniformDistribution
<
ck
::
half_t
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
ck
::
half_t
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
}
Tensor
<
ck
::
half_t
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
ck
::
half_t
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ck
::
half_t
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ck
::
half_t
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
ck
::
half_t
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
auto
a_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
auto
b_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
auto
c_element_op
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
{};
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ck
::
half_t
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ck
::
half_t
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ck
::
half_t
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
true
;
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
flop
=
2
_uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ck
::
half_t
)
*
M
*
K
+
sizeof
(
ck
::
half_t
)
*
K
*
N
+
sizeof
(
ck
::
half_t
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
config
.
do_verification
)
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
}
return
true
;
}
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
ExecutionConfig
config
;
return
!
parse_cmd_args
(
argc
,
argv
,
problem_size
,
config
)
||
run_gemm
(
problem_size
,
config
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
python/ait_impl/generation/ex/shared/Makefile
View file @
adbefd90
CFLAGS
=
-I
~/workspace/composable_kernel/include
-I
/opt/workspace/rocm-5.1.1/hip/include
-I
~/workspace/composable_kernel/include/
-I
~/workspace/composable_kernel/include/ck/
-I
~/workspace/composable_kernel/example/01_gemm/
-I
~/workspace/composable_kernel/library/include/
-I
~/workspace/composable_kernel/library/src/utility/
-I
~/workspace/composable_kernel/include/ck/problem_transform/
-I
~/workspace/composable_kernel/include/ck/tensor/
-I
~/workspace/composable_kernel/include/ck/tensor_description/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/
-I
~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/
-I
~/workspace/composable_kernel/include/ck/host_utility
-I
/external/include/half/
-I
~/workspace/composable_kernel/library/include/ck/library/host/
-I
~/workspace/composable_kernel/library/include/ck/library/host_tensor/
-I
~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/
-I
~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/
-I
~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/
" + "
reduce/
-I
~/workspace/composable_kernel/library/include/ck/library/tensor_op/
-I
~/workspace/composable_kernel/library/include/ck/library/utility/
-I
~/workspace/composable_kernel/profiler/include/
CXXFLAGS
=
-std
=
c++17
test.so
:
ex.o host_tensor.o device_memory.o
hipcc
-shared
$(CXXFLAGS)
$(CFLAGS)
ex.o host_tensor.o device_memory.o
-o
test.so
device_memory.o
:
../../../../../library/src/utility/device_memory.cpp
hipcc
-fPIC
-fvisibility
=
hidden
$(CXXFLAGS)
$(CFLAGS)
-c
../../../../../library/src/utility/device_memory.cpp
host_tensor.o
:
../../../../../library/src/utility/host_tensor.cpp
hipcc
-fPIC
-fvisibility
=
hidden
$(CXXFLAGS)
$(CFLAGS)
-c
../../../../../library/src/utility/host_tensor.cpp
hipcc
-shared
-fPIC
-fvisibility
=
hidden
$(CXXFLAGS)
$(CFLAGS)
-c
../../../../../library/src/utility/host_tensor.cpp
obj_files
=
256.o
%.o
:
%.cpp
hipcc
-fPIC
-fvisibility
=
hidden
$(CXXFLAGS)
-w
/opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc
$(CFLAGS)
-L
/opt/rocm-5.3.0/rocrand
-lrocrand
-x
hip
-c
$<
all
:
test.so
test.so
:
$(obj_files) host_tensor.o device_memory.o
hipcc
-shared
$(CXXFLAGS)
$(CFLAGS)
-o
$@
$(obj_files)
host_tensor.o device_memory.o
clean
:
rm
-f
*
.o test.so
ex.o
:
hipcc
-fPIC
-fvisibility
=
hidden
$(CXXFLAGS)
-w
/opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc
$(CFLAGS)
-L
/opt/rocm-5.3.0/rocrand
-lrocrand
-x
hip
-c
ex.cpp
python/ait_impl/generation/ex/shared/__pycache__/ck_types.cpython-38.pyc
0 → 100644
View file @
adbefd90
File added
python/ait_impl/generation/ex/shared/__pycache__/gemm_ex.cpython-38.pyc
0 → 100644
View file @
adbefd90
File added
python/ait_impl/generation/ex/shared/__pycache__/gemm_op.cpython-38.pyc
0 → 100644
View file @
adbefd90
File added
python/ait_impl/generation/ex/shared/__pycache__/user.cpython-38.pyc
0 → 100644
View file @
adbefd90
File added
python/ait_impl/generation/ex/shared/ck_types.py
0 → 100644
View file @
adbefd90
from
dataclasses
import
dataclass
class DataType:
    """String tags for the CK element data types used in generated C++ code."""

    f16 = "ck::half_t"


class Layout:
    """String tags for the CK GEMM tensor memory layouts."""

    ColumnMajor = "ck::tensor_layout::gemm::ColumnMajor"
    RowMajor = "ck::tensor_layout::gemm::RowMajor"


class TensorOperation:
    """String tags for the CK element-wise tensor operations."""

    PassThrough = "ck::tensor_operation::element_wise::PassThrough"


@dataclass
class TensorDesc:
    """Pairs an element data type with a memory layout for one GEMM operand.

    NOTE(review): the annotations name DataType/Layout, but the values
    actually stored are the plain strings defined on those classes.
    """

    element: DataType
    layout: Layout
python/ait_impl/generation/ex/shared/driver.py
0 → 100644
View file @
adbefd90
import enum
import ck_types
from copy import deepcopy
from dataclasses import dataclass
from enum import auto
from typing import List
import os.path
import shutil
import functools
import operator
import collections
import subprocess
import re
import gemm_op
from gemm_op import *
import user
from ck_types import *
from gemm_ex import *

# Driver entry point: build every GEMM configuration declared in user.py and
# emit (and compile) the C++ source for each one.
# NOTE(review): several imports above (enum, deepcopy, dataclass, auto, List,
# shutil, functools, operator, collections, subprocess, re) look unused in
# this file — verify before removing, since the star imports obscure usage.

# holds multiple gemm instances
op_collection = user.CreateGemmOperator()

for op in op_collection:
    # EmitGemmInstance comes from gemm_ex via the star import; emit(op)
    # renders the C++ template for this configuration.
    x = EmitGemmInstance()
    x.emit(op)
\ No newline at end of file
python/ait_impl/generation/ex/shared/gemm_ex.py
View file @
adbefd90
...
...
@@ -6,6 +6,9 @@ import operator
import
collections
import
subprocess
import
re
import
gemm_op
from
gemm_op
import
*
import
user
def
SubstituteTemplate
(
template
,
values
):
text
=
template
...
...
@@ -23,22 +26,6 @@ def SubstituteTemplate(template, values):
class
EmitGemmInstance
:
def
__init__
(
self
):
self
.
make_template
=
"""
CFLAGS=-I ~/workspace/composable_kernel/include -I /opt/workspace/rocm-5.1.1/hip/include -I ~/workspace/composable_kernel/include/ -I ~/workspace/composable_kernel/include/ck/ -I ~/workspace/composable_kernel/example/01_gemm/ -I ~/workspace/composable_kernel/library/include/ -I ~/workspace/composable_kernel/library/src/utility/ -I ~/workspace/composable_kernel/include/ck/problem_transform/ -I ~/workspace/composable_kernel/include/ck/tensor/ -I ~/workspace/composable_kernel/include/ck/tensor_description/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/ -I ~/workspace/composable_kernel/include/ck/host_utility -I /external/include/half/ -I ~/workspace/composable_kernel/library/include/ck/library/host/ -I ~/workspace/composable_kernel/library/include/ck/library/host_tensor/ -I ~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/" + "reduce/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_op/ -I ~/workspace/composable_kernel/library/include/ck/library/utility/ -I ~/workspace/composable_kernel/profiler/include/
CXXFLAGS = -std=c++17
gemm: ex.o host_tensor.o device_memory.o
hipcc $(CXXFLAGS) $(CFLAGS) ex.o host_tensor.o device_memory.o -o gemm
device_memory.o: ../../../../library/src/utility/device_memory.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/device_memory.cpp
host_tensor.o: ../../../../library/src/utility/host_tensor.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/host_tensor.cpp
ex.o:
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w /opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c ex.cpp
"""
self
.
gemm_devop_template
=
"""
#pragma once
...
...
@@ -223,71 +210,73 @@ bool run_gemm_example(int argc, char* argv[])
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
"""
def
emit
(
self
):
def
emit
(
self
,
operation
):
values
=
{
'type_a'
:
'ck::half_t'
,
'type_b'
:
'ck::half_t'
,
'type_c'
:
'ck::half_t'
,
'type_a'
:
operation
.
A
.
element
,
'type_b'
:
operation
.
B
.
element
,
'type_c'
:
operation
.
C
.
element
,
'type_acc'
:
'float'
,
'layout_a'
:
'ck::tensor_layout::gemm::ColumnMajor'
,
'layout_b'
:
'ck::tensor_layout::gemm::RowMajor'
,
'layout_c'
:
'ck::tensor_layout::gemm::RowMajor'
,
'elementwise_op_a'
:
'ck::tensor_
operation
::
elem
ent_wise::PassThrough'
,
'elementwise_op_b'
:
'ck::tensor_
operation
::
elem
ent_wise::PassThrough'
,
'elementwise_op_c'
:
'ck::tensor_operation::element_wise::PassThrough'
,
'Gemm_spec'
:
'ck::tensor_operation::device::G
emm
S
pecialization
::Default'
,
'block_size'
:
'256'
,
'mperblock'
:
'128'
,
'nperblock'
:
'128'
,
'k0perblock'
:
'16'
,
'k1'
:
'2'
,
'm1perthread'
:
'4'
,
'n1perthread'
:
'4'
,
'kperthread'
:
'1'
,
'm1n1_thcluster_m1xs'
:
'S<8, 2>'
,
'm1n1_thcluster_n1xs'
:
'S<8, 2>'
,
'ABT_thread_slice_lengths_K0_M0_M1_K1'
:
'S<2, 1, 4, 2>'
,
'ABT_thread_cluster_lengths_K0_M0_M1_K1'
:
'S<8, 1, 32, 1>'
,
'ABT_thread_cluster_arrange_order'
:
'S<0, 3, 1, 2>'
,
'ABT_src_access_order'
:
'S<0, 3, 1, 2>'
,
'ABT_src_vec_tensor_lengths_K0_M0_M1_K1'
:
'S<1, 1, 4, 1>'
,
'ABT_src_vec_tensor_cont_dim_order'
:
'S<0, 3, 1, 2>'
,
'ABT_dst_vec_tensor_lengths_K0_M0_M1_K1'
:
'S<1, 1, 4, 2>'
,
'BBT_thread_slice_lengths_K0_N0_N1_K1'
:
'S<2, 1, 4, 2>'
,
'BBT_thread_cluster_lengths_K0_N0_N1_K1'
:
'S<8, 1, 32, 1>'
,
'BBT_thread_cluster_arrange_order'
:
'S<0, 3, 1, 2>'
,
'BBT_src_access_order'
:
'S<0, 3, 1, 2>'
,
'BBT_src_vec_tensor_lengths_K0_N0_N1_K1'
:
'S<1, 1, 4, 1>'
,
'BBT_src_vec_tensor_cont_dim_order'
:
'S<0, 3, 1, 2>'
,
'BBT_dst_vec_tensor_lengths_K0_N0_N1_K1'
:
'S<1, 1, 4, 2>'
,
'CTT_src_dst_access_order'
:
'S<0, 1, 2, 3, 4, 5>'
,
'CTT_src_dst_vec_dim'
:
'5'
,
'CTT_dst_scalar_per_vector'
:
'4'
'layout_a'
:
operation
.
A
.
layout
,
'layout_b'
:
operation
.
B
.
layout
,
'layout_c'
:
operation
.
C
.
layout
,
'elementwise_op_a'
:
operation
.
a_
elem
_op
,
'elementwise_op_b'
:
operation
.
b_
elem
_op
,
'elementwise_op_c'
:
operation
.
epilogue_functor
,
'Gemm_spec'
:
operation
.
g
emm
_s
pecialization
,
'block_size'
:
str
(
operation
.
tile_desc
.
block_size
)
,
'mperblock'
:
str
(
operation
.
tile_desc
.
m_per_block
)
,
'nperblock'
:
str
(
operation
.
tile_desc
.
n_per_block
)
,
'k0perblock'
:
str
(
operation
.
tile_desc
.
k_per_block
)
,
'k1'
:
str
(
operation
.
tile_desc
.
k1
)
,
'm1perthread'
:
str
(
operation
.
tile_desc
.
m_per_thread
)
,
'n1perthread'
:
str
(
operation
.
tile_desc
.
n_per_thread
)
,
'kperthread'
:
str
(
operation
.
tile_desc
.
k_per_thread
)
,
'm1n1_thcluster_m1xs'
:
operation
.
tile_desc
.
m1n1_thcluster_m1xs
,
'm1n1_thcluster_n1xs'
:
operation
.
tile_desc
.
m1n1_thcluster_n1xs
,
'ABT_thread_slice_lengths_K0_M0_M1_K1'
:
operation
.
a_block_transfer
.
thread_slice_length
,
'ABT_thread_cluster_lengths_K0_M0_M1_K1'
:
operation
.
a_block_transfer
.
thread_cluster_length
,
'ABT_thread_cluster_arrange_order'
:
operation
.
a_block_transfer
.
thread_cluster_arrange_order
,
'ABT_src_access_order'
:
operation
.
a_block_transfer
.
src_access_order
,
'ABT_src_vec_tensor_lengths_K0_M0_M1_K1'
:
operation
.
a_block_transfer
.
src_vec_tensor_lengths
,
'ABT_src_vec_tensor_cont_dim_order'
:
operation
.
a_block_transfer
.
src_vec_tensor_cont_dim_order
,
'ABT_dst_vec_tensor_lengths_K0_M0_M1_K1'
:
operation
.
a_block_transfer
.
dst_vec_tensor_lengths
,
'BBT_thread_slice_lengths_K0_N0_N1_K1'
:
operation
.
b_block_transfer
.
thread_slice_length
,
'BBT_thread_cluster_lengths_K0_N0_N1_K1'
:
operation
.
b_block_transfer
.
thread_cluster_length
,
'BBT_thread_cluster_arrange_order'
:
operation
.
b_block_transfer
.
thread_cluster_arrange_order
,
'BBT_src_access_order'
:
operation
.
b_block_transfer
.
src_access_order
,
'BBT_src_vec_tensor_lengths_K0_N0_N1_K1'
:
operation
.
b_block_transfer
.
src_vec_tensor_lengths
,
'BBT_src_vec_tensor_cont_dim_order'
:
operation
.
b_block_transfer
.
src_vec_tensor_cont_dim_order
,
'BBT_dst_vec_tensor_lengths_K0_N0_N1_K1'
:
operation
.
b_block_transfer
.
dst_vec_tensor_lengths
,
'CTT_src_dst_access_order'
:
operation
.
c_block_transfer
.
src_dst_access_order
,
'CTT_src_dst_vec_dim'
:
str
(
operation
.
c_block_transfer
.
src_dst_vec_dim
)
,
'CTT_dst_scalar_per_vector'
:
str
(
operation
.
c_block_transfer
.
dst_scalar_per_vector
),
}
template
=
self
.
gemm_devop_template
cf
=
open
(
"ex.cpp"
,
'w'
)
name
=
str
(
operation
.
tile_desc
.
block_size
)
cf
=
open
(
"%s.cpp"
%
name
,
'w'
)
print
(
SubstituteTemplate
(
template
,
values
))
cf
.
write
(
SubstituteTemplate
(
template
,
values
))
cf
.
close
()
m_template
=
self
.
make_template
cf
=
open
(
"Makefile"
,
'w'
)
print
(
SubstituteTemplate
(
m_template
,
values
))
cf
.
write
(
SubstituteTemplate
(
m_template
,
values
))
cf
.
close
()
PIPE
=
-
1
STDOUT
=
-
2
proc
=
subprocess
.
Popen
(
[
"make"
],
shell
=
True
,
env
=
os
.
environ
.
copy
(),
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
)
out
,
err
=
proc
.
communicate
()
a
=
EmitGemmInstance
()
a
.
emit
()
# A = TensorDesc(DataType.f16, Layout.RowMajor)
# B = TensorDesc(DataType.f16, Layout.ColumnMajor)
# C = TensorDesc(DataType.f16, Layout.RowMajor)
# gemm = gemm_op.GemmOperation(
# A=A,
# B=B,
# C=C,
# a_elem_op=TensorOperation.PassThrough,
# b_elem_op=TensorOperation.PassThrough,
# epilogue_functor=TensorOperation.PassThrough,
# gemm_specialization=GemmType.GemmDefault,
# tile_desc=TileDesc(256, 256, 128, 32, 8, 2, 32, 32, 1, [8,2], [8,2]),
# a_block_transfer=BlockTransferDesc(
# [2, 1, 4, 2], [8, 1, 32, 1], [0, 3, 1, 2], [0, 3, 1, 2],[1, 1, 4, 1], [0, 3, 1, 2], [1, 1, 4, 2]
# ),
# b_block_transfer=BlockTransferDesc(
# [2, 1, 4, 2], [8, 1, 32, 1], [0, 3, 1, 2], [0, 3, 1, 2], [1, 1, 4, 1], [0, 3, 1, 2], [1, 1, 4, 2]
# ),
# c_block_transfer=CBlockTransferDesc([0, 1, 2, 3, 4, 5], 5, 4),
# )
# a = EmitGemmInstance()
# a.emit(gemm)
python/ait_impl/generation/ex/shared/gemm_op.py
0 → 100644
View file @
adbefd90
#take in input for gemm from user, send it to example template
import
enum
import
ck_types
from
copy
import
deepcopy
from
dataclasses
import
dataclass
from
enum
import
auto
from
typing
import
List
from
ck_types
import
*
class GemmType():
    """String tag for the CK GEMM specialization chosen at code-gen time."""

    # Default specialization: no M/N/K padding variant.
    GemmDefault = "ck::tensor_operation::device::GemmSpecialization::Default"
# class GemmSpecialization(enum.Enum):
# GemmDefault = auto()
# MNKPadding = auto()
# MNPadding = auto()
# MNOPadding = auto()
# MNKOPadding = auto()
# GemmSpecializationTag = {
# GemmSpecialization.GemmDefault: "ck::tensor_operation::device::GemmSpecialization::Default",
# GemmSpecialization.MNKPadding: "ck::tensor_operation::device::GemmSpecialization::MNKPadding",
# GemmSpecialization.MNPadding: "ck::tensor_operation::device::GemmSpecialization::MNPadding",
# GemmSpecialization.MNOPadding: "ck::tensor_operation::device::GemmSpecialization::MNOPadding",
# GemmSpecialization.MNKOPadding: "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
# }
@dataclass
class TileDesc:
    """Block/thread tiling parameters for one generated GEMM kernel.

    The values are substituted verbatim into the C++ device-op template; the
    two ``m1n1_thcluster_*`` fields hold CK sequence literals such as
    ``"S<8, 2>"``.
    """

    block_size: int   # threads per workgroup
    m_per_block: int
    n_per_block: int
    k_per_block: int
    k1: int
    m_per_thread: int
    n_per_thread: int
    k_per_thread: int
    m1n1_thcluster_m1xs: str  # e.g. "S<8, 2>"
    m1n1_thcluster_n1xs: str  # e.g. "S<8, 2>"

    def __str__(self) -> str:
        """Return the underscore-joined field values (used for kernel names).

        Bug fix: removed an unreachable second ``return
        template.render(param=args)`` that followed the real return and
        referenced names (``template``, ``args``) not defined anywhere in
        this scope.
        """
        return "_".join(str(v) for v in self.__dict__.values())
@dataclass
class BlockTransferDesc:
    """Block-transfer tuning parameters for an A/B operand copy pipeline.

    Every field holds a CK sequence literal (e.g. ``"S<2, 1, 4, 2>"``) that
    is pasted verbatim into the C++ template.
    """

    thread_slice_length: str
    thread_cluster_length: str
    thread_cluster_arrange_order: str
    src_access_order: str
    src_vec_tensor_lengths: str
    src_vec_tensor_cont_dim_order: str
    dst_vec_tensor_lengths: str

    def __str__(self) -> str:
        """Return the underscore-joined field values (mirrors TileDesc).

        Bug fix: the previous implementation deep-copied ``__dict__``,
        rewrote a few entries, and then fell off the end without returning,
        so ``str()`` on an instance raised TypeError. The per-character list
        comprehensions it ran were leftovers from when these fields were
        lists of ints rather than "S<...>" strings.
        """
        return "_".join(str(v) for v in self.__dict__.values())
@dataclass
class CBlockTransferDesc:
    """Block-transfer tuning parameters for the C (output) operand."""

    src_dst_access_order: str  # e.g. "S<0, 1, 2, 3, 4, 5>"
    src_dst_vec_dim: int
    dst_scalar_per_vector: int

    def __str__(self) -> str:
        """Return the underscore-joined field values (mirrors TileDesc).

        Bug fix: the previous implementation only deep-copied ``__dict__``
        and returned None, so ``str()`` on an instance raised TypeError.
        """
        # args["m_n_block_wave_per_xdl"] = [str(x) for x in self.m_n_block_wave_per_xdl]
        return "_".join(str(v) for v in self.__dict__.values())
@dataclass
class GemmOperation:
    """Full description of one GEMM kernel instance to be code-generated."""

    A: TensorDesc
    B: TensorDesc
    C: TensorDesc
    a_elem_op: TensorOperation
    b_elem_op: TensorOperation
    epilogue_functor: TensorOperation
    gemm_specialization: GemmType  # GemmSpecialization
    tile_desc: TileDesc
    a_block_transfer: BlockTransferDesc
    b_block_transfer: BlockTransferDesc
    b1_block_transfer: BlockTransferDesc = None  # optional, unused by plain GEMM
    c_block_transfer: CBlockTransferDesc = None

    def __str__(self) -> str:
        """Build a unique, human-readable name for this configuration."""
        io_name = "{gemm_kind}_{gemm_specialization}_{a_dtype}{b_dtype}{c_dtype}_{a_layout}{b_layout}{c_layout}".format(
            # gemm_kind=library.GemmKindNames[self.operation_kind],
            # TODO(review): placeholder until GemmKindNames is wired up; the
            # original supplied no gemm_kind kwarg at all (KeyError).
            gemm_kind="gemm",
            # Bug fix: gemm_specialization is a plain string tag (see
            # GemmType), not an enum member, so it has no `.value`.
            gemm_specialization=self.gemm_specialization,
            a_dtype=[self.A.element],
            b_dtype=[self.B.element],
            c_dtype=[self.C.element],
            a_layout=[self.A.layout],
            b_layout=[self.B.layout],
            c_layout=[self.C.layout],
        )
        extra_tile = ""
        if self.c_block_transfer is not None:
            # Bug fix: the field is named dst_scalar_per_vector; the
            # original read a nonexistent `scalar_per_vector` attribute.
            if self.c_block_transfer.dst_scalar_per_vector == 4:
                extra_tile = "_C4"
            elif self.c_block_transfer.dst_scalar_per_vector == 1:
                extra_tile = "_C1"
        tile_name = str(self.tile_desc) + extra_tile
        return "{io_name}_{tile_name}_{epilogue_functor}".format(
            io_name=io_name,
            tile_name=tile_name,
            epilogue_functor=[self.epilogue_functor],
        )

    def accumulator_type(self):
        """Accumulator element type for this GEMM."""
        # f.32? — NOTE(review): f16 accumulation is suspicious for GEMM;
        # confirm whether this should be a float32 tag.
        return DataType.f16
if __name__ == "__main__":
    # Smoke test: build one GemmOperation by hand and print a field.
    # Bug fix: the original demo passed argument lists that did not match
    # the dataclass signatures (BlockTransferDesc got 8 args for 7 fields,
    # CBlockTransferDesc got 4 for 3, and TileDesc got ints where CK
    # sequence strings are expected), raising TypeError. The values below
    # mirror the working configurations in user.py.
    A = TensorDesc(DataType.f16, Layout.RowMajor)
    B = TensorDesc(DataType.f16, Layout.ColumnMajor)
    C = TensorDesc(DataType.f16, Layout.RowMajor)
    GemmOp = GemmOperation(
        A=A,
        B=B,
        C=C,
        a_elem_op=TensorOperation.PassThrough,
        b_elem_op=TensorOperation.PassThrough,
        epilogue_functor=TensorOperation.PassThrough,
        gemm_specialization=GemmType.GemmDefault,
        tile_desc=TileDesc(256, 128, 128, 16, 2, 4, 4, 1, "S<8, 2>", "S<8, 2>"),
        a_block_transfer=BlockTransferDesc(
            "S<2, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>",
            "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"
        ),
        b_block_transfer=BlockTransferDesc(
            "S<2, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>",
            "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"
        ),
        c_block_transfer=CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        # ds_dtype=[DataType.f16],
    )
    print(GemmOp.a_elem_op)
python/ait_impl/generation/ex/shared/make_template.py
0 → 100644
View file @
adbefd90
import
enum
import
os.path
import
shutil
import
functools
import
operator
import
collections
import
subprocess
import
re
import
gemm_op
from
gemm_op
import
*
import
user
def SubstituteTemplate(template, values):
    """Expand every ``${key}`` placeholder in ``template`` using ``values``.

    Substitution repeats until a full pass changes nothing, so replacement
    text that itself contains placeholders is expanded as well.
    """
    result = template
    dirty = True
    while dirty:
        dirty = False
        for name, replacement in values.items():
            pattern = "\\$\\{%s\\}" % name
            substituted = re.sub(pattern, replacement, result)
            if substituted != result:
                result = substituted
                dirty = True
    return result
class EmitMake:
    # Writes a Makefile for the generated kernel sources and runs `make`.
    # The emitted Makefile compiles every %.cpp kernel (obj_files) plus the
    # CK host_tensor/device_memory utilities into a shared library test.so.
    def __init__(self):
        # NOTE(review): the include paths and ROCm locations below are
        # hard-coded to one developer's workspace (rocm-5.1.1 / rocm-5.3.0)
        # — confirm before reuse. The literal `" + "` mid-line is preserved
        # from the source and looks like a leftover string concatenation.
        self.make_template = """
CFLAGS=-I ~/workspace/composable_kernel/include -I /opt/workspace/rocm-5.1.1/hip/include -I ~/workspace/composable_kernel/include/ -I ~/workspace/composable_kernel/include/ck/ -I ~/workspace/composable_kernel/example/01_gemm/ -I ~/workspace/composable_kernel/library/include/ -I ~/workspace/composable_kernel/library/src/utility/ -I ~/workspace/composable_kernel/include/ck/problem_transform/ -I ~/workspace/composable_kernel/include/ck/tensor/ -I ~/workspace/composable_kernel/include/ck/tensor_description/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/ -I ~/workspace/composable_kernel/include/ck/host_utility -I /external/include/half/ -I ~/workspace/composable_kernel/library/include/ck/library/host/ -I ~/workspace/composable_kernel/library/include/ck/library/host_tensor/ -I ~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/" + "reduce/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_op/ -I ~/workspace/composable_kernel/library/include/ck/library/utility/ -I ~/workspace/composable_kernel/profiler/include/
CXXFLAGS = -std=c++17
device_memory.o: ../../../../../library/src/utility/device_memory.cpp
	hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) $(CFLAGS) -c ../../../../../library/src/utility/device_memory.cpp
host_tensor.o: ../../../../../library/src/utility/host_tensor.cpp
	hipcc -shared -fPIC -fvisibility=hidden $(CXXFLAGS) $(CFLAGS) -c ../../../../../library/src/utility/host_tensor.cpp
obj_files = 256.o
%.o : %.cpp
	hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w /opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c $<
all: test.so
test.so: $(obj_files) host_tensor.o device_memory.o
	hipcc -shared $(CXXFLAGS) $(CFLAGS) -o $@ $(obj_files) host_tensor.o device_memory.o
clean:
	rm -f *.o test.so
"""

    def emit(self, operation):
        # Render the Makefile template and kick off the build.
        # `operation` is currently unused; the template has no real
        # placeholders, so `values` carries only a dummy key.
        values = {
            'temp': ""
        }
        m_template = self.make_template
        cf = open("Makefile", 'w')
        # Echo the rendered Makefile for debugging, then write it out.
        print(SubstituteTemplate(m_template, values))
        cf.write(SubstituteTemplate(m_template, values))
        cf.close()
        # NOTE(review): PIPE/STDOUT below are unused locals shadowing the
        # subprocess constants; subprocess.PIPE is used directly instead.
        PIPE = -1
        STDOUT = -2
        proc = subprocess.Popen(
            ["make"],
            shell=True,
            env=os.environ.copy(),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        # Block until make finishes; build output/errors are captured but
        # not inspected here.
        out, err = proc.communicate()
\ No newline at end of file
python/ait_impl/generation/ex/shared/user.py
0 → 100644
View file @
adbefd90
import
gemm_op
as
gemm
import
enum
from
dataclasses
import
dataclass
from
enum
import
auto
import
ck_types
from
ck_types
import
*
def CreateGemmOperator():
    """Build and return the list of GemmOperation configurations to emit.

    Operand types/layouts and element-wise ops are fixed (f16, col-major A,
    row-major B/C, PassThrough); the tile and block-transfer descriptions
    are zipped positionally, so the i-th entries of each list must belong
    together.
    """
    #operation_kind = library.GemmKind.Gemm
    a_element_desc = TensorDesc(DataType.f16, Layout.ColumnMajor)
    b_element_desc = TensorDesc(DataType.f16, Layout.RowMajor)
    c_element_desc = TensorDesc(DataType.f16, Layout.RowMajor)
    element_op = TensorOperation.PassThrough
    # One TileDesc per kernel instance; strings are CK sequence literals.
    tile_descriptions = [
        gemm.TileDesc(256, 128, 128, 16, 2, 4, 4, 1, "S<8, 2>", "S<8, 2>"),
        gemm.TileDesc(128, 128, 128, 32, 2, 32, 32, 1, "S<8, 2>", "S<8, 2>"),
        # gemm.TileDesc(256, 128, 256, 32, 8, 2, 32, 32, 2, 4),
        # gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4),
        # gemm.TileDesc(128, 128, 128, 32, 8, 2, 32, 32, 4, 2),
        # gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2),
        # gemm.TileDesc(256, 128, 128, 32, 8, 2, 32, 32, 2, 2),
        # gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2),
        # gemm.TileDesc(128, 128, 64, 32, 8, 2, 32, 32, 2, 2),
        # gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2),
        # gemm.TileDesc(128, 64, 128, 32, 8, 2, 32, 32, 2, 2),
        # gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2),
        # gemm.TileDesc(256, 128, 64, 32, 8, 2, 32, 32, 2, 1),
        # gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1),
        # gemm.TileDesc(256, 64, 128, 32, 8, 2, 32, 32, 1, 2),
        # gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2),
    ]
    # B-operand block-transfer tuning, parallel to tile_descriptions.
    b_block_descriptions = [
        gemm.BlockTransferDesc(
            "S<2, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>",
            "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"
        ),
        gemm.BlockTransferDesc(
            "S<2, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>",
            "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"
        ),
        # gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0),
        # gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1),
        # gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0),
        # gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1),
        # gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0),
        # gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1),
        # gemm.BlockTransferDesc([8, 16, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0),
        # gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1),
        # gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0),
        # gemm.BlockTransferDesc([4, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 8, 1),
        # gemm.BlockTransferDesc([16, 16, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0),
        # gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 1, 8, 1),
        # gemm.BlockTransferDesc([8, 32, 1], [0, 2, 1], [0, 2, 1], 1, 4, 2, 0),
        # gemm.BlockTransferDesc([4, 64, 1], [0, 2, 1], [0, 2, 1], 1, 2, 8, 1),
    ]
    # C-operand (output) block-transfer tuning, also parallel.
    c_block_descriptions = [
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        gemm.CBlockTransferDesc("S<1, 2, 3, 5, 5, 6>", 6, 5),
    ]
    # A reuses the same tuning as B for now.
    a_block_descriptions = b_block_descriptions
    #c_block_descriptions = []
    # AIT logic, adapt later
    # for t in tile_descriptions:
    #     a_block_transfer = -1
    #     c_block_transfer = -1
    #     if t.block_size == 256:
    #         a_block_transfer = [4, 64, 1]
    #         c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 8], 8)
    #     if t.block_size == 128:
    #         a_block_transfer = [4, 32, 1]
    #         if t.n_per_block == 128:
    #             c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 16, 1, 8], 8)
    #         if t.n_per_block == 64:
    #             c_block_transfer = gemm.CBlockTransferDesc(1, 1, [1, 32, 1, 4], 8)
    #     assert (
    #         a_block_transfer != -1
    #         and c_block_transfer != -1
    #         and "Cannot determine block_transfer_size with block_size "
    #         + str(t.block_size)
    #     )
    #     a_block_descriptions.append(
    #         gemm.BlockTransferDesc(a_block_transfer, [1, 0, 2], [1, 0, 2], 2, 8, 8, 1)
    #     )
    #     c_block_descriptions.append(c_block_transfer)
    gemm_specialization = [gemm.GemmType.GemmDefault]
    operations = []
    # Cartesian product of specializations with the zipped, position-matched
    # tile/A/B/C descriptions.
    for gemm_spec in gemm_specialization:
        for tile_desc, a_block_desc, b_block_desc, c_block_desc in zip(
            tile_descriptions,
            a_block_descriptions,
            b_block_descriptions,
            c_block_descriptions,
        ):
            new_operation = gemm.GemmOperation(
                #operation_kind=operation_kind,
                A=a_element_desc,
                B=b_element_desc,
                C=c_element_desc,
                a_elem_op=element_op,
                b_elem_op=element_op,
                epilogue_functor=element_op,
                gemm_specialization=gemm_spec,
                tile_desc=tile_desc,
                a_block_transfer=a_block_desc,
                b_block_transfer=b_block_desc,
                c_block_transfer=c_block_desc,
            )
            #manifest.append(new_operation)
            operations.append(new_operation)
    return operations
    # NOTE(review): this print is unreachable (it follows the return) —
    # presumably leftover debug output; its original indentation is unclear
    # from the scrape. Confirm and remove.
    print(operations[0].tile_desc)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment