Commit 8ff88928 authored by Astha Rai
Browse files

cleaned up code, added comments

parent 5714d3c6
...@@ -14,6 +14,9 @@ class TensorOperation: ...@@ -14,6 +14,9 @@ class TensorOperation:
PassThrough = "PassThrough" PassThrough = "PassThrough"
Bilinear = "Bilinear" Bilinear = "Bilinear"
class GemmType():
GemmDefault = "ck::tensor_operation::device::GemmSpecialization::Default"
@dataclass @dataclass
class TensorDesc: #set up and import properly class TensorDesc: #set up and import properly
element: DataType element: DataType
......
...@@ -16,10 +16,11 @@ from gemm_op import * ...@@ -16,10 +16,11 @@ from gemm_op import *
import user import user
from ck_types import * from ck_types import *
from gemm_ex import * from gemm_ex import *
#from make_template import *
# holds multiple gemm instances # holds multiple gemm instances
op_collection = user.CreateGemmOperator() op_collection = user.CreateGemmOperator()
# emit for each instance
for op in op_collection: for op in op_collection:
x = EmitGemmInstance() x = EmitGemmInstance()
x.emit(op) x.emit(op)
......
...@@ -10,6 +10,7 @@ import gemm_op ...@@ -10,6 +10,7 @@ import gemm_op
from gemm_op import * from gemm_op import *
import user import user
# function to substitute values into template
def SubstituteTemplate(template, values): def SubstituteTemplate(template, values):
text = template text = template
changed = True changed = True
...@@ -23,7 +24,7 @@ def SubstituteTemplate(template, values): ...@@ -23,7 +24,7 @@ def SubstituteTemplate(template, values):
text = newtext text = newtext
return text return text
# setting up the template with all the user input
class EmitGemmInstance: class EmitGemmInstance:
def __init__(self): def __init__(self):
self.gemm_op_template = """ self.gemm_op_template = """
...@@ -31,6 +32,8 @@ class EmitGemmInstance: ...@@ -31,6 +32,8 @@ class EmitGemmInstance:
DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layout_e}, ${type_a}, ${type_b}, ${type_acc}, ${type_cshuffle}, ${type_ds}, ${type_e}, ${elementwise_op_a}, ${elementwise_op_b}, ${elementwise_op_cde}, ${Gemm_spec}, ${num_gemmk_prefetch_stage}, ${block_size}, ${mperblock}, ${nperblock}, ${kperblock}, ${ak1}, ${bk1}, ${mperXDL}, ${nperXDL}, ${mXdlperwave}, ${nXdlperwave}, ${ABT_thread_cluster_lengths_K0_M_K1}, ${ABT_thread_cluster_arrange_order}, ${ABT_src_access_order}, ${ABT_src_vec_dim}, ${ABT_src_scalar_per_vec}, ${ABT_dst_scalar_per_vec_k1}, ${ABT_lds_add_extra_m}, ${BBT_thread_cluster_lengths_K0_N_K1}, ${BBT_thread_cluster_arrange_order}, ${BBT_src_access_order}, ${BBT_src_vec_dim}, ${BBT_src_scalar_per_vec}, ${BBT_dst_scalar_per_vec_k1}, ${BBT_lds_add_extra_n}, ${CS_m_Xdl_per_wave_per_shuffle}, ${CS_n_Xdl_per_wave_per_shuffle}, ${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, ${CTT_scalar_per_vector_n_wave_n_per_Xdl}>, DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layout_e}, ${type_a}, ${type_b}, ${type_acc}, ${type_cshuffle}, ${type_ds}, ${type_e}, ${elementwise_op_a}, ${elementwise_op_b}, ${elementwise_op_cde}, ${Gemm_spec}, ${num_gemmk_prefetch_stage}, ${block_size}, ${mperblock}, ${nperblock}, ${kperblock}, ${ak1}, ${bk1}, ${mperXDL}, ${nperXDL}, ${mXdlperwave}, ${nXdlperwave}, ${ABT_thread_cluster_lengths_K0_M_K1}, ${ABT_thread_cluster_arrange_order}, ${ABT_src_access_order}, ${ABT_src_vec_dim}, ${ABT_src_scalar_per_vec}, ${ABT_dst_scalar_per_vec_k1}, ${ABT_lds_add_extra_m}, ${BBT_thread_cluster_lengths_K0_N_K1}, ${BBT_thread_cluster_arrange_order}, ${BBT_src_access_order}, ${BBT_src_vec_dim}, ${BBT_src_scalar_per_vec}, ${BBT_dst_scalar_per_vec_k1}, ${BBT_lds_add_extra_n}, ${CS_m_Xdl_per_wave_per_shuffle}, ${CS_n_Xdl_per_wave_per_shuffle}, ${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, ${CTT_scalar_per_vector_n_wave_n_per_Xdl}>,
""" """
# function that takes in operation from gemm_op and gets tuning parameters
def emit(self,operation): def emit(self,operation):
#name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block) + "_" + str(operation.tile_desc.ak1)) #name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block) + "_" + str(operation.tile_desc.ak1))
values = { values = {
...@@ -42,7 +45,7 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou ...@@ -42,7 +45,7 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou
'type_a' : operation.A.element, 'type_a' : operation.A.element,
'type_b' : operation.B.element, 'type_b' : operation.B.element,
'type_acc' : operation.acc, 'type_acc' : operation.acc,
'type_cshuffle' : operation.cs_type, #figure out how to arrange this 'type_cshuffle' : operation.cs_type,
'type_ds' : operation.Ds.element, 'type_ds' : operation.Ds.element,
'type_e' : operation.E.element, 'type_e' : operation.E.element,
'elementwise_op_a' : operation.a_elem_op, 'elementwise_op_a' : operation.a_elem_op,
...@@ -79,13 +82,15 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou ...@@ -79,13 +82,15 @@ DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layou
'CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl' : operation.c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl, 'CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl' : operation.c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl,
'CTT_scalar_per_vector_n_wave_n_per_Xdl' : str(operation.c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl), 'CTT_scalar_per_vector_n_wave_n_per_Xdl' : str(operation.c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl),
} }
template = self.gemm_op_template
name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)
+ "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.ak1))
# print(SubstituteTemplate(template, values)) # name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)
# + "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.ak1))
# defining the template to use and substituting the values
template = self.gemm_op_template
instances = SubstituteTemplate(template, values) instances = SubstituteTemplate(template, values)
print(instances) print(instances)
# cf = open("instances.cpp",'w') # cf = open("instances.cpp",'w')
# cf.write(SubstituteTemplate(template, values)) # cf.write(SubstituteTemplate(template, values))
# cf.close() # cf.close()
......
...@@ -9,8 +9,6 @@ from enum import auto ...@@ -9,8 +9,6 @@ from enum import auto
from typing import List from typing import List
from ck_types import * from ck_types import *
class GemmType():
GemmDefault = "ck::tensor_operation::device::GemmSpecialization::Default"
@dataclass @dataclass
class TileDesc: class TileDesc:
...@@ -88,3 +86,4 @@ class GemmOperation: ...@@ -88,3 +86,4 @@ class GemmOperation:
a_layout=[self.A.layout], a_layout=[self.A.layout],
b_layout=[self.B.layout], b_layout=[self.B.layout],
) )
...@@ -28,6 +28,8 @@ def CreateGemmOperator(): ...@@ -28,6 +28,8 @@ def CreateGemmOperator():
acc_type = DataType.f16 acc_type = DataType.f16
cshuffle_type = DataType.f32 cshuffle_type = DataType.f32
# Tile Desc: (block_size, m_per_block, n_per_block, k_per_block, ak1, bk1,
# m_per_XDL, n_per_XDL, m_Xdl_per_wave, n_Xdl_per_wave, num_gemmk_prefetch_stage)
tile_descriptions = [ tile_descriptions = [
gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1), gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1),
gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1), gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1),
...@@ -44,6 +46,8 @@ def CreateGemmOperator(): ...@@ -44,6 +46,8 @@ def CreateGemmOperator():
gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2, 1), gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2, 1),
] ]
# BlockTransferDesc: (thread_cluster_length, thread_cluster_arrange_order, src_access_order, src_vec_dim,
# src_scalar_per_vector, dst_scalar_per_vector_k1, lds_add_extra_dim )
a_block_descriptions = [ a_block_descriptions = [
gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1), gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1), gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
...@@ -76,6 +80,7 @@ def CreateGemmOperator(): ...@@ -76,6 +80,7 @@ def CreateGemmOperator():
gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1), gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
] ]
# cshuffle_descriptions: (m_Xdl_per_wave_per_shuffle, n_Xdl_per_wave_per_shuffle)
cshuffle_descriptions = [ cshuffle_descriptions = [
gemm.CShuffleDesc(1,1), gemm.CShuffleDesc(1,1),
gemm.CShuffleDesc(1,1), gemm.CShuffleDesc(1,1),
...@@ -91,6 +96,7 @@ def CreateGemmOperator(): ...@@ -91,6 +96,7 @@ def CreateGemmOperator():
gemm.CShuffleDesc(1,1), gemm.CShuffleDesc(1,1),
] ]
# CBlockTransferDesc: (cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl, scalar_per_vector_n_wave_n_per_Xdl)
c_block_descriptions = [ c_block_descriptions = [
gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8), gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8), gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
...@@ -111,6 +117,8 @@ def CreateGemmOperator(): ...@@ -111,6 +117,8 @@ def CreateGemmOperator():
gemm_specialization = [ gemm_specialization = [
gemm.GemmType.GemmDefault gemm.GemmType.GemmDefault
] ]
# set up and return list of instances using ^tuning parameters
operations = [] operations = []
for gemm_spec in gemm_specialization: for gemm_spec in gemm_specialization:
for tile_desc, a_block_desc, b_block_desc, cshuffle_desc, c_block_desc in zip( for tile_desc, a_block_desc, b_block_desc, cshuffle_desc, c_block_desc in zip(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment