Commit 8ff88928 authored by Astha Rai's avatar Astha Rai
Browse files

cleaned up code, added comments

parent 5714d3c6
......@@ -14,6 +14,9 @@ class TensorOperation:
PassThrough = "PassThrough"
Bilinear = "Bilinear"
class GemmType():
GemmDefault = "ck::tensor_operation::device::GemmSpecialization::Default"
@dataclass
class TensorDesc: #set up and import properly
element: DataType
......
......@@ -16,10 +16,11 @@ from gemm_op import *
import user
from ck_types import *
from gemm_ex import *
#from make_template import *
# holds multiple gemm instances
op_collection = user.CreateGemmOperator()
# emit for each instance
for op in op_collection:
x = EmitGemmInstance()
x.emit(op)
......
......@@ -10,6 +10,7 @@ import gemm_op
from gemm_op import *
import user
# function to substitute values into template
def SubstituteTemplate(template, values):
text = template
changed = True
......@@ -23,7 +24,7 @@ def SubstituteTemplate(template, values):
text = newtext
return text
# setting up the template with all the user input
class EmitGemmInstance:
def __init__(self):
self.gemm_op_template = """
......@@ -31,6 +32,8 @@ class EmitGemmInstance:
DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layout_e}, ${type_a}, ${type_b}, ${type_acc}, ${type_cshuffle}, ${type_ds}, ${type_e}, ${elementwise_op_a}, ${elementwise_op_b}, ${elementwise_op_cde}, ${Gemm_spec}, ${num_gemmk_prefetch_stage}, ${block_size}, ${mperblock}, ${nperblock}, ${kperblock}, ${ak1}, ${bk1}, ${mperXDL}, ${nperXDL}, ${mXdlperwave}, ${nXdlperwave}, ${ABT_thread_cluster_lengths_K0_M_K1}, ${ABT_thread_cluster_arrange_order}, ${ABT_src_access_order}, ${ABT_src_vec_dim}, ${ABT_src_scalar_per_vec}, ${ABT_dst_scalar_per_vec_k1}, ${ABT_lds_add_extra_m}, ${BBT_thread_cluster_lengths_K0_N_K1}, ${BBT_thread_cluster_arrange_order}, ${BBT_src_access_order}, ${BBT_src_vec_dim}, ${BBT_src_scalar_per_vec}, ${BBT_dst_scalar_per_vec_k1}, ${BBT_lds_add_extra_n}, ${CS_m_Xdl_per_wave_per_shuffle}, ${CS_n_Xdl_per_wave_per_shuffle}, ${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, ${CTT_scalar_per_vector_n_wave_n_per_Xdl}>,
"""
# function that takes in operation from gemm_op and gets tuning parameters
def emit(self,operation):
#name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block) + "_" + str(operation.tile_desc.ak1))
values = {
......@@ -42,7 +45,7 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou
'type_a' : operation.A.element,
'type_b' : operation.B.element,
'type_acc' : operation.acc,
'type_cshuffle' : operation.cs_type, #figure out how to arrange this
'type_cshuffle' : operation.cs_type,
'type_ds' : operation.Ds.element,
'type_e' : operation.E.element,
'elementwise_op_a' : operation.a_elem_op,
......@@ -79,13 +82,15 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou
'CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl' : operation.c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl,
'CTT_scalar_per_vector_n_wave_n_per_Xdl' : str(operation.c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl),
}
template = self.gemm_op_template
name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)
+ "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.ak1))
# print(SubstituteTemplate(template, values))
# name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)
# + "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.ak1))
# defining the template to use and substituting the values
template = self.gemm_op_template
instances = SubstituteTemplate(template, values)
print(instances)
# cf = open("instances.cpp",'w')
# cf.write(SubstituteTemplate(template, values))
# cf.close()
......
......@@ -9,8 +9,6 @@ from enum import auto
from typing import List
from ck_types import *
class GemmType():
GemmDefault = "ck::tensor_operation::device::GemmSpecialization::Default"
@dataclass
class TileDesc:
......@@ -88,3 +86,4 @@ class GemmOperation:
a_layout=[self.A.layout],
b_layout=[self.B.layout],
)
......@@ -28,6 +28,8 @@ def CreateGemmOperator():
acc_type = DataType.f16
cshuffle_type = DataType.f32
# Tile Desc: (block_size, m_per_block, n_per_block, k_per_block, ak1, bk1,
# m_per_XDL, n_per_XDL, m_Xdl_per_wave, n_Xdl_per_wave, num_gemmk_prefetch_stage)
tile_descriptions = [
gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1),
gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1),
......@@ -44,6 +46,8 @@ def CreateGemmOperator():
gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2, 1),
]
# BlockTransferDesc: (thread_cluster_length, thread_cluster_arrange_order, src_access_order, src_vec_dim,
# src_scalar_per_vector, dst_scalar_per_vector_k1, lds_add_extra_dim )
a_block_descriptions = [
gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
......@@ -76,6 +80,7 @@ def CreateGemmOperator():
gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
]
# cshuffle_descriptions: (m_Xdl_per_wave_per_shuffle, n_Xdl_per_wave_per_shuffle)
cshuffle_descriptions = [
gemm.CShuffleDesc(1,1),
gemm.CShuffleDesc(1,1),
......@@ -91,6 +96,7 @@ def CreateGemmOperator():
gemm.CShuffleDesc(1,1),
]
# CBlockTransferDesc: (cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl, scalar_per_vector_n_wave_n_per_Xdl)
c_block_descriptions = [
gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
......@@ -111,6 +117,8 @@ def CreateGemmOperator():
gemm_specialization = [
gemm.GemmType.GemmDefault
]
# set up and return list of instances using ^tuning parameters
operations = []
for gemm_spec in gemm_specialization:
for tile_desc, a_block_desc, b_block_desc, cshuffle_desc, c_block_desc in zip(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment