Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8ff88928
Commit
8ff88928
authored
Sep 14, 2023
by
Astha Rai
Browse files
cleaned up code, added comments
parent
5714d3c6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
9 deletions
+25
-9
instance_gen/AIT_impl/generation/instance/ck_types.py
instance_gen/AIT_impl/generation/instance/ck_types.py
+3
-0
instance_gen/AIT_impl/generation/instance/driver.py
instance_gen/AIT_impl/generation/instance/driver.py
+2
-1
instance_gen/AIT_impl/generation/instance/gemm_ex.py
instance_gen/AIT_impl/generation/instance/gemm_ex.py
+11
-6
instance_gen/AIT_impl/generation/instance/gemm_op.py
instance_gen/AIT_impl/generation/instance/gemm_op.py
+1
-2
instance_gen/AIT_impl/generation/instance/user.py
instance_gen/AIT_impl/generation/instance/user.py
+8
-0
No files found.
instance_gen/AIT_impl/generation/instance/ck_types.py
View file @
8ff88928
...
@@ -14,6 +14,9 @@ class TensorOperation:
...
@@ -14,6 +14,9 @@ class TensorOperation:
PassThrough
=
"PassThrough"
PassThrough
=
"PassThrough"
Bilinear
=
"Bilinear"
Bilinear
=
"Bilinear"
class
GemmType
():
GemmDefault
=
"ck::tensor_operation::device::GemmSpecialization::Default"
@
dataclass
@
dataclass
class
TensorDesc
:
#set up and import properly
class
TensorDesc
:
#set up and import properly
element
:
DataType
element
:
DataType
...
...
instance_gen/AIT_impl/generation/instance/driver.py
View file @
8ff88928
...
@@ -16,10 +16,11 @@ from gemm_op import *
...
@@ -16,10 +16,11 @@ from gemm_op import *
import
user
import
user
from
ck_types
import
*
from
ck_types
import
*
from
gemm_ex
import
*
from
gemm_ex
import
*
#from make_template import *
# holds multiple gemm instances
# holds multiple gemm instances
op_collection
=
user
.
CreateGemmOperator
()
op_collection
=
user
.
CreateGemmOperator
()
# emit for each instance
for
op
in
op_collection
:
for
op
in
op_collection
:
x
=
EmitGemmInstance
()
x
=
EmitGemmInstance
()
x
.
emit
(
op
)
x
.
emit
(
op
)
...
...
instance_gen/AIT_impl/generation/instance/gemm_ex.py
View file @
8ff88928
...
@@ -10,6 +10,7 @@ import gemm_op
...
@@ -10,6 +10,7 @@ import gemm_op
from
gemm_op
import
*
from
gemm_op
import
*
import
user
import
user
# function to substitute values into template
def
SubstituteTemplate
(
template
,
values
):
def
SubstituteTemplate
(
template
,
values
):
text
=
template
text
=
template
changed
=
True
changed
=
True
...
@@ -23,7 +24,7 @@ def SubstituteTemplate(template, values):
...
@@ -23,7 +24,7 @@ def SubstituteTemplate(template, values):
text
=
newtext
text
=
newtext
return
text
return
text
# setting up the template with all the user input
class
EmitGemmInstance
:
class
EmitGemmInstance
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
gemm_op_template
=
"""
self
.
gemm_op_template
=
"""
...
@@ -31,6 +32,8 @@ class EmitGemmInstance:
...
@@ -31,6 +32,8 @@ class EmitGemmInstance:
DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layout_e}, ${type_a}, ${type_b}, ${type_acc}, ${type_cshuffle}, ${type_ds}, ${type_e}, ${elementwise_op_a}, ${elementwise_op_b}, ${elementwise_op_cde}, ${Gemm_spec}, ${num_gemmk_prefetch_stage}, ${block_size}, ${mperblock}, ${nperblock}, ${kperblock}, ${ak1}, ${bk1}, ${mperXDL}, ${nperXDL}, ${mXdlperwave}, ${nXdlperwave}, ${ABT_thread_cluster_lengths_K0_M_K1}, ${ABT_thread_cluster_arrange_order}, ${ABT_src_access_order}, ${ABT_src_vec_dim}, ${ABT_src_scalar_per_vec}, ${ABT_dst_scalar_per_vec_k1}, ${ABT_lds_add_extra_m}, ${BBT_thread_cluster_lengths_K0_N_K1}, ${BBT_thread_cluster_arrange_order}, ${BBT_src_access_order}, ${BBT_src_vec_dim}, ${BBT_src_scalar_per_vec}, ${BBT_dst_scalar_per_vec_k1}, ${BBT_lds_add_extra_n}, ${CS_m_Xdl_per_wave_per_shuffle}, ${CS_n_Xdl_per_wave_per_shuffle}, ${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, ${CTT_scalar_per_vector_n_wave_n_per_Xdl}>,
DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layout_e}, ${type_a}, ${type_b}, ${type_acc}, ${type_cshuffle}, ${type_ds}, ${type_e}, ${elementwise_op_a}, ${elementwise_op_b}, ${elementwise_op_cde}, ${Gemm_spec}, ${num_gemmk_prefetch_stage}, ${block_size}, ${mperblock}, ${nperblock}, ${kperblock}, ${ak1}, ${bk1}, ${mperXDL}, ${nperXDL}, ${mXdlperwave}, ${nXdlperwave}, ${ABT_thread_cluster_lengths_K0_M_K1}, ${ABT_thread_cluster_arrange_order}, ${ABT_src_access_order}, ${ABT_src_vec_dim}, ${ABT_src_scalar_per_vec}, ${ABT_dst_scalar_per_vec_k1}, ${ABT_lds_add_extra_m}, ${BBT_thread_cluster_lengths_K0_N_K1}, ${BBT_thread_cluster_arrange_order}, ${BBT_src_access_order}, ${BBT_src_vec_dim}, ${BBT_src_scalar_per_vec}, ${BBT_dst_scalar_per_vec_k1}, ${BBT_lds_add_extra_n}, ${CS_m_Xdl_per_wave_per_shuffle}, ${CS_n_Xdl_per_wave_per_shuffle}, ${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, ${CTT_scalar_per_vector_n_wave_n_per_Xdl}>,
"""
"""
# function that takes in operation from gemm_op and gets tuning parameters
def
emit
(
self
,
operation
):
def
emit
(
self
,
operation
):
#name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block) + "_" + str(operation.tile_desc.ak1))
#name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block) + "_" + str(operation.tile_desc.ak1))
values
=
{
values
=
{
...
@@ -42,7 +45,7 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou
...
@@ -42,7 +45,7 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou
'type_a'
:
operation
.
A
.
element
,
'type_a'
:
operation
.
A
.
element
,
'type_b'
:
operation
.
B
.
element
,
'type_b'
:
operation
.
B
.
element
,
'type_acc'
:
operation
.
acc
,
'type_acc'
:
operation
.
acc
,
'type_cshuffle'
:
operation
.
cs_type
,
#figure out how to arrange this
'type_cshuffle'
:
operation
.
cs_type
,
'type_ds'
:
operation
.
Ds
.
element
,
'type_ds'
:
operation
.
Ds
.
element
,
'type_e'
:
operation
.
E
.
element
,
'type_e'
:
operation
.
E
.
element
,
'elementwise_op_a'
:
operation
.
a_elem_op
,
'elementwise_op_a'
:
operation
.
a_elem_op
,
...
@@ -79,13 +82,15 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou
...
@@ -79,13 +82,15 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou
'CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl'
:
operation
.
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
,
'CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl'
:
operation
.
c_block_transfer
.
cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl
,
'CTT_scalar_per_vector_n_wave_n_per_Xdl'
:
str
(
operation
.
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
),
'CTT_scalar_per_vector_n_wave_n_per_Xdl'
:
str
(
operation
.
c_block_transfer
.
scalar_per_vector_n_wave_n_per_Xdl
),
}
}
template
=
self
.
gemm_op_template
name
=
(
str
(
operation
.
tile_desc
.
block_size
)
+
"_"
+
str
(
operation
.
tile_desc
.
m_per_block
)
+
"_"
+
str
(
operation
.
tile_desc
.
n_per_block
)
+
"_"
+
str
(
operation
.
tile_desc
.
k_per_block
)
+
"_"
+
str
(
operation
.
tile_desc
.
ak1
))
# print(SubstituteTemplate(template, values))
# name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)
# + "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.ak1))
# defining the template to use and substituting the values
template
=
self
.
gemm_op_template
instances
=
SubstituteTemplate
(
template
,
values
)
instances
=
SubstituteTemplate
(
template
,
values
)
print
(
instances
)
print
(
instances
)
# cf = open("instances.cpp",'w')
# cf = open("instances.cpp",'w')
# cf.write(SubstituteTemplate(template, values))
# cf.write(SubstituteTemplate(template, values))
# cf.close()
# cf.close()
...
...
instance_gen/AIT_impl/generation/instance/gemm_op.py
View file @
8ff88928
...
@@ -9,8 +9,6 @@ from enum import auto
...
@@ -9,8 +9,6 @@ from enum import auto
from
typing
import
List
from
typing
import
List
from
ck_types
import
*
from
ck_types
import
*
class
GemmType
():
GemmDefault
=
"ck::tensor_operation::device::GemmSpecialization::Default"
@
dataclass
@
dataclass
class
TileDesc
:
class
TileDesc
:
...
@@ -88,3 +86,4 @@ class GemmOperation:
...
@@ -88,3 +86,4 @@ class GemmOperation:
a_layout
=
[
self
.
A
.
layout
],
a_layout
=
[
self
.
A
.
layout
],
b_layout
=
[
self
.
B
.
layout
],
b_layout
=
[
self
.
B
.
layout
],
)
)
instance_gen/AIT_impl/generation/instance/user.py
View file @
8ff88928
...
@@ -28,6 +28,8 @@ def CreateGemmOperator():
...
@@ -28,6 +28,8 @@ def CreateGemmOperator():
acc_type
=
DataType
.
f16
acc_type
=
DataType
.
f16
cshuffle_type
=
DataType
.
f32
cshuffle_type
=
DataType
.
f32
# Tile Desc: (block_size, m_per_block, n_per_block, k_per_block, ak1, bk1,
# m_per_XDL, n_per_XDL, m_Xdl_per_wave, n_Xdl_per_wave, num_gemmk_prefetch_stage)
tile_descriptions
=
[
tile_descriptions
=
[
gemm
.
TileDesc
(
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
1
),
gemm
.
TileDesc
(
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
1
),
gemm
.
TileDesc
(
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
1
),
gemm
.
TileDesc
(
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
1
),
...
@@ -44,6 +46,8 @@ def CreateGemmOperator():
...
@@ -44,6 +46,8 @@ def CreateGemmOperator():
gemm
.
TileDesc
(
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
1
),
gemm
.
TileDesc
(
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
1
),
]
]
# BlockTransferDesc: (thread_cluster_length, thread_cluster_arrange_order, src_access_order, src_vec_dim,
# src_scalar_per_vector, dst_scalar_per_vector_k1, lds_add_extra_dim )
a_block_descriptions
=
[
a_block_descriptions
=
[
gemm
.
BlockTransferDesc
(
"S<4, 64, 1>"
,
"S<1, 0, 2>"
,
"S<1, 0, 2>"
,
2
,
8
,
8
,
1
),
gemm
.
BlockTransferDesc
(
"S<4, 64, 1>"
,
"S<1, 0, 2>"
,
"S<1, 0, 2>"
,
2
,
8
,
8
,
1
),
gemm
.
BlockTransferDesc
(
"S<4, 64, 1>"
,
"S<1, 0, 2>"
,
"S<1, 0, 2>"
,
2
,
8
,
8
,
1
),
gemm
.
BlockTransferDesc
(
"S<4, 64, 1>"
,
"S<1, 0, 2>"
,
"S<1, 0, 2>"
,
2
,
8
,
8
,
1
),
...
@@ -76,6 +80,7 @@ def CreateGemmOperator():
...
@@ -76,6 +80,7 @@ def CreateGemmOperator():
gemm
.
BlockTransferDesc
(
"S<4, 16, 1>"
,
"S<1, 0, 2>"
,
"S<1, 0, 2>"
,
2
,
8
,
8
,
1
),
gemm
.
BlockTransferDesc
(
"S<4, 16, 1>"
,
"S<1, 0, 2>"
,
"S<1, 0, 2>"
,
2
,
8
,
8
,
1
),
]
]
# cshuffle_descriptions: (m_Xdl_per_wave_per_shuffle, n_Xdl_per_wave_per_shuffle)
cshuffle_descriptions
=
[
cshuffle_descriptions
=
[
gemm
.
CShuffleDesc
(
1
,
1
),
gemm
.
CShuffleDesc
(
1
,
1
),
gemm
.
CShuffleDesc
(
1
,
1
),
gemm
.
CShuffleDesc
(
1
,
1
),
...
@@ -91,6 +96,7 @@ def CreateGemmOperator():
...
@@ -91,6 +96,7 @@ def CreateGemmOperator():
gemm
.
CShuffleDesc
(
1
,
1
),
gemm
.
CShuffleDesc
(
1
,
1
),
]
]
# CBlockTransferDesc: (cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl, scalar_per_vector_n_wave_n_per_Xdl)
c_block_descriptions
=
[
c_block_descriptions
=
[
gemm
.
CBlockTransferDesc
(
"S<1, 32, 1, 8>"
,
8
),
gemm
.
CBlockTransferDesc
(
"S<1, 32, 1, 8>"
,
8
),
gemm
.
CBlockTransferDesc
(
"S<1, 32, 1, 8>"
,
8
),
gemm
.
CBlockTransferDesc
(
"S<1, 32, 1, 8>"
,
8
),
...
@@ -111,6 +117,8 @@ def CreateGemmOperator():
...
@@ -111,6 +117,8 @@ def CreateGemmOperator():
gemm_specialization
=
[
gemm_specialization
=
[
gemm
.
GemmType
.
GemmDefault
gemm
.
GemmType
.
GemmDefault
]
]
# set up and return list of instances using ^tuning parameters
operations
=
[]
operations
=
[]
for
gemm_spec
in
gemm_specialization
:
for
gemm_spec
in
gemm_specialization
:
for
tile_desc
,
a_block_desc
,
b_block_desc
,
cshuffle_desc
,
c_block_desc
in
zip
(
for
tile_desc
,
a_block_desc
,
b_block_desc
,
cshuffle_desc
,
c_block_desc
in
zip
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment