"test/srt/vscode:/vscode.git/clone" did not exist on "dca87ec34801e4a541cf8324e977522a2b06c067"
Commit 5714d3c6 authored by Astha Rai

Added a preliminary version of an instance generator for MIGraphX - generates a string of instances for the device_gemm_bilinear example.
parent 475188ca
from dataclasses import dataclass
# String constants naming CK data types, layouts, and elementwise operations
# (this block appears to be ck_types.py, judging by the `import ck_types` statements below).
class DataType:
    f16 = "F16"
    f32 = "F32"
    f16_tuple = "F16_Tuple"

class Layout:
    ColumnMajor = "Col"
    RowMajor = "Row"
    Row_Tuple = "Row_Tuple"

class TensorOperation:
    PassThrough = "PassThrough"
    Bilinear = "Bilinear"

@dataclass
class TensorDesc:  # TODO: set up and import properly
    element: DataType
    layout: Layout
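# Illustrative usage (not in the original file): these classes act as plain string enums, e.g.
#   desc = TensorDesc(DataType.f16, Layout.ColumnMajor)
#   desc.element  # "F16"
#   desc.layout   # "Col"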
import enum
import ck_types
from copy import deepcopy
from dataclasses import dataclass
from enum import auto
from typing import List
import os.path
import shutil
import functools
import operator
import collections
import subprocess
import re
import gemm_op
from gemm_op import *
import user
from ck_types import *
from gemm_ex import *
#from make_template import *
# holds multiple gemm instances
op_collection = user.CreateGemmOperator()
for op in op_collection:
    x = EmitGemmInstance()
    x.emit(op)
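# Note (assumption about intended use): running this script prints one
# DeviceGemmMultipleD_Xdl_CShuffle<...> instance string per GemmOperation
# returned by user.CreateGemmOperator(), one for each tile configuration it defines.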
import enum
import os.path
import shutil
import functools
import operator
import collections
import subprocess
import re
import gemm_op
from gemm_op import *
import user
def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text
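# Worked example (illustrative, not in the original file): the loop keeps replacing
# ${key} placeholders until the text stops changing, so
#   SubstituteTemplate("Gemm<${type_a}, ${type_b}>", {"type_a": "F16", "type_b": "F16"})
# evaluates to "Gemm<F16, F16>".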
class EmitGemmInstance:
    def __init__(self):
        # single-line template for one DeviceGemmMultipleD_Xdl_CShuffle instance; each
        # ${...} placeholder is filled in by emit() via SubstituteTemplate()
        self.gemm_op_template = """
DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layout_e}, ${type_a}, ${type_b}, ${type_acc}, ${type_cshuffle}, ${type_ds}, ${type_e}, ${elementwise_op_a}, ${elementwise_op_b}, ${elementwise_op_cde}, ${Gemm_spec}, ${num_gemmk_prefetch_stage}, ${block_size}, ${mperblock}, ${nperblock}, ${kperblock}, ${ak1}, ${bk1}, ${mperXDL}, ${nperXDL}, ${mXdlperwave}, ${nXdlperwave}, ${ABT_thread_cluster_lengths_K0_M_K1}, ${ABT_thread_cluster_arrange_order}, ${ABT_src_access_order}, ${ABT_src_vec_dim}, ${ABT_src_scalar_per_vec}, ${ABT_dst_scalar_per_vec_k1}, ${ABT_lds_add_extra_m}, ${BBT_thread_cluster_lengths_K0_N_K1}, ${BBT_thread_cluster_arrange_order}, ${BBT_src_access_order}, ${BBT_src_vec_dim}, ${BBT_src_scalar_per_vec}, ${BBT_dst_scalar_per_vec_k1}, ${BBT_lds_add_extra_n}, ${CS_m_Xdl_per_wave_per_shuffle}, ${CS_n_Xdl_per_wave_per_shuffle}, ${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, ${CTT_scalar_per_vector_n_wave_n_per_Xdl}>,
"""
    def emit(self, operation):
        #name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block) + "_" + str(operation.tile_desc.ak1))
        values = {
            #'name' : name,
            'layout_a' : operation.A.layout,
            'layout_b' : operation.B.layout,
            'layout_ds' : operation.Ds.layout,
            'layout_e' : operation.E.layout,
            'type_a' : operation.A.element,
            'type_b' : operation.B.element,
            'type_acc' : operation.acc,
            'type_cshuffle' : operation.cs_type, #figure out how to arrange this
            'type_ds' : operation.Ds.element,
            'type_e' : operation.E.element,
            'elementwise_op_a' : operation.a_elem_op,
            'elementwise_op_b' : operation.b_elem_op,
            'elementwise_op_cde' : operation.cde_elem_op,
            'Gemm_spec' : operation.gemm_specialization,
            'num_gemmk_prefetch_stage' : str(operation.tile_desc.num_gemmk_prefetch_stage),
            'block_size' : str(operation.tile_desc.block_size),
            'mperblock' : str(operation.tile_desc.m_per_block),
            'nperblock' : str(operation.tile_desc.n_per_block),
            'kperblock' : str(operation.tile_desc.k_per_block),
            'ak1' : str(operation.tile_desc.ak1),
            'bk1' : str(operation.tile_desc.bk1),
            'mperXDL' : str(operation.tile_desc.m_per_XDL),
            'nperXDL' : str(operation.tile_desc.n_per_XDL),
            'mXdlperwave' : str(operation.tile_desc.m_Xdl_per_wave),
            'nXdlperwave' : str(operation.tile_desc.n_Xdl_per_wave),
            'ABT_thread_cluster_lengths_K0_M_K1' : operation.a_block_transfer.thread_cluster_length,
            'ABT_thread_cluster_arrange_order' : operation.a_block_transfer.thread_cluster_arrange_order,
            'ABT_src_access_order' : operation.a_block_transfer.src_access_order,
            'ABT_src_vec_dim' : str(operation.a_block_transfer.src_vec_dim),
            'ABT_src_scalar_per_vec' : str(operation.a_block_transfer.src_scalar_per_vector),
            'ABT_dst_scalar_per_vec_k1' : str(operation.a_block_transfer.dst_scalar_per_vector_k1),
            'ABT_lds_add_extra_m' : str(operation.a_block_transfer.lds_add_extra_dim),
            'BBT_thread_cluster_lengths_K0_N_K1' : operation.b_block_transfer.thread_cluster_length,
            'BBT_thread_cluster_arrange_order' : operation.b_block_transfer.thread_cluster_arrange_order,
            'BBT_src_access_order' : operation.b_block_transfer.src_access_order,
            'BBT_src_vec_dim' : str(operation.b_block_transfer.src_vec_dim),
            'BBT_src_scalar_per_vec' : str(operation.b_block_transfer.src_scalar_per_vector),
            'BBT_dst_scalar_per_vec_k1' : str(operation.b_block_transfer.dst_scalar_per_vector_k1),
            'BBT_lds_add_extra_n' : str(operation.b_block_transfer.lds_add_extra_dim),
            'CS_m_Xdl_per_wave_per_shuffle' : str(operation.cshuffle.m_Xdl_per_wave_per_shuffle),
            'CS_n_Xdl_per_wave_per_shuffle' : str(operation.cshuffle.n_Xdl_per_wave_per_shuffle),
            'CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl' : operation.c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl,
            'CTT_scalar_per_vector_n_wave_n_per_Xdl' : str(operation.c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl),
        }
        template = self.gemm_op_template
        # name is currently unused; kept for the commented-out per-instance file output below
        name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)
                + "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.ak1))
        # print(SubstituteTemplate(template, values))
        instances = SubstituteTemplate(template, values)
        print(instances)
        # cf = open("instances.cpp",'w')
        # cf.write(SubstituteTemplate(template, values))
        # cf.close()
        # cf = open("%s.cpp" % name,'w')
        # cf.write(SubstituteTemplate(template, values))
        # cf.close()
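        # Sketch (assumption about intended use, based on the commented-out block above):
        # to collect the emitted instances in a file instead of printing them, one could
        # append each substituted string, e.g.
        #   with open("instances.cpp", "a") as cf:
        #       cf.write(instances)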
# Takes the gemm input from the user module and feeds it into the example template.
# The structure for constructing this gemm op was taken from AIT's
# implementation of creating a gemm op.
# (This block appears to be gemm_op.py, judging by the `import gemm_op` statements elsewhere.)
import enum
import ck_types
from copy import deepcopy
from dataclasses import dataclass
from enum import auto
from typing import List
from ck_types import *
class GemmType:
    GemmDefault = "ck::tensor_operation::device::GemmSpecialization::Default"

@dataclass
class TileDesc:
    block_size: int
    m_per_block: int
    n_per_block: int
    k_per_block: int
    ak1: int
    bk1: int
    m_per_XDL: int
    n_per_XDL: int
    m_Xdl_per_wave: int
    n_Xdl_per_wave: int
    num_gemmk_prefetch_stage: int
    def __str__(self) -> str:
        values = list(self.__dict__.values())
        # assumption: join the field values into a name-friendly string (the original returned nothing)
        return "_".join(str(v) for v in values)
@dataclass
class BlockTransferDesc:
    thread_cluster_length: str
    thread_cluster_arrange_order: str
    src_access_order: str
    src_vec_dim: int
    src_scalar_per_vector: int
    dst_scalar_per_vector_k1: int
    lds_add_extra_dim: int
    def __str__(self) -> str:
        args = deepcopy(self.__dict__)
        # assumption: join the field values into a name-friendly string (the original returned nothing)
        return "_".join(str(v) for v in args.values())

@dataclass
class CShuffleDesc:
    m_Xdl_per_wave_per_shuffle: int
    n_Xdl_per_wave_per_shuffle: int
    def __str__(self) -> str:
        args = deepcopy(self.__dict__)
        return "_".join(str(v) for v in args.values())

@dataclass
class CBlockTransferDesc:
    cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl: str
    scalar_per_vector_n_wave_n_per_Xdl: int
    def __str__(self) -> str:
        args = deepcopy(self.__dict__)
        return "_".join(str(v) for v in args.values())
@dataclass
class GemmOperation:
    A: TensorDesc
    B: TensorDesc
    acc: DataType
    cs_type: DataType
    Ds: TensorDesc
    E: TensorDesc
    a_elem_op: TensorOperation
    b_elem_op: TensorOperation
    cde_elem_op: TensorOperation
    gemm_specialization: GemmType  #GemmSpecialization
    tile_desc: TileDesc
    a_block_transfer: BlockTransferDesc
    b_block_transfer: BlockTransferDesc
    cshuffle: CShuffleDesc
    b1_block_transfer: BlockTransferDesc = None
    c_block_transfer: CBlockTransferDesc = None
    def __str__(self) -> str:
        # gemm_kind is omitted until operation_kind is wired up; E supplies the output
        # dtype/layout (assumption, the original left c_dtype/c_layout unfilled and returned nothing)
        io_name = "{gemm_specialization}_{a_dtype}{b_dtype}{e_dtype}_{a_layout}{b_layout}{e_layout}".format(
            #gemm_kind=library.GemmKindNames[self.operation_kind],
            gemm_specialization=self.gemm_specialization,
            a_dtype=self.A.element,
            b_dtype=self.B.element,
            e_dtype=self.E.element,
            a_layout=self.A.layout,
            b_layout=self.B.layout,
            e_layout=self.E.layout,
        )
        return io_name
DeviceGemmMultipleD_Xdl_CShuffle<
${layout_a},
${layout_b},
${layout_ds},
${layout_e},
${type_a},
${type_b},
${type_acc},
${type_cshuffle},
${type_ds},
${type_e},
${elementwise_op_a},
${elementwise_op_b},
${elementwise_op_cde},
${Gemm_spec},
${num_gemmk_prefetch_stage},
${block_size},
${mperblock},
${nperblock},
${kperblock},
${ak1},
${bk1},
${mperXDL},
${nperXDL},
${mXdlperwave},
${nXdlperwave},
${ABT_thread_cluster_lengths_K0_M_K1},
${ABT_thread_cluster_arrange_order},
${ABT_src_access_order},
${ABT_src_vec_dim},
${ABT_src_scalar_per_vec},
${ABT_dst_scalar_per_vec_k1},
${ABT_lds_add_extra_m},
${BBT_thread_cluster_lengths_K0_N_K1},
${BBT_thread_cluster_arrange_order},
${BBT_src_access_order},
${BBT_src_vec_dim},
${BBT_src_scalar_per_vec},
${BBT_dst_scalar_per_vec_k1},
${BBT_lds_add_extra_n},
${CS_m_Xdl_per_wave_per_shuffle},
${CS_n_Xdl_per_wave_per_shuffle},
${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
${CTT_scalar_per_vector_n_wave_n_per_Xdl}>;
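// The ${...} placeholders above correspond one-to-one, and in the same order, to the keys
// filled in by EmitGemmInstance.emit() through SubstituteTemplate(); this multi-line form
// appears to be the readable counterpart of the single-line gemm_op_template string.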
# The structure for creating a list of instances for an op
# was taken from Meta's AIT library.
# (This block appears to be user.py, judging by `import user` in the driver script above.)
import gemm_op as gemm
import enum
from dataclasses import dataclass
from enum import auto
import ck_types
from ck_types import *
def CreateGemmOperator():
    #operation_kind = library.GemmKind.Gemm
    a_element_desc = TensorDesc(DataType.f16, Layout.ColumnMajor)
    b_element_desc = TensorDesc(DataType.f16, Layout.RowMajor)
    ds_element_desc = TensorDesc(DataType.f16_tuple, Layout.Row_Tuple)
    e_element_desc = TensorDesc(DataType.f16, Layout.RowMajor)
    a_element_op = TensorOperation.PassThrough
    b_element_op = TensorOperation.PassThrough
    cde_element_op = TensorOperation.Bilinear
    acc_type = DataType.f16
    cshuffle_type = DataType.f32
    tile_descriptions = [
        gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1),
        gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1),
        gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1),
        gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1),
        gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1),
        gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1),
        gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1),
        gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1),
        gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1),
        gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1, 1),
        gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2, 1),
        gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1),
        gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2, 1),
    ]
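    # For reference (field order taken from gemm_op.TileDesc): each entry above is
    # (block_size, m_per_block, n_per_block, k_per_block, ak1, bk1, m_per_XDL, n_per_XDL,
    #  m_Xdl_per_wave, n_Xdl_per_wave, num_gemmk_prefetch_stage).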
    a_block_descriptions = [
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
    ]
    b_block_descriptions = [
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
    ]
    cshuffle_descriptions = [
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        # 13th entry added so zip() below does not silently drop the last tile description
        # (the other description lists all have 13 entries)
        gemm.CShuffleDesc(1, 1),
    ]
    c_block_descriptions = [
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 4>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 4>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 4>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 4>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 4>", 8),
    ]
    #a_block_descriptions = b_block_descriptions
    gemm_specialization = [
        gemm.GemmType.GemmDefault
    ]
    operations = []
    for gemm_spec in gemm_specialization:
        for tile_desc, a_block_desc, b_block_desc, cshuffle_desc, c_block_desc in zip(
            tile_descriptions,
            a_block_descriptions,
            b_block_descriptions,
            cshuffle_descriptions,
            c_block_descriptions,
        ):
            new_operation = gemm.GemmOperation(
                #operation_kind=operation_kind,
                A=a_element_desc,
                B=b_element_desc,
                acc=acc_type,
                cs_type=cshuffle_type,
                Ds=ds_element_desc,
                E=e_element_desc,
                a_elem_op=a_element_op,
                b_elem_op=b_element_op,
                cde_elem_op=cde_element_op,
                gemm_specialization=gemm_spec,
                tile_desc=tile_desc,
                a_block_transfer=a_block_desc,
                b_block_transfer=b_block_desc,
                cshuffle=cshuffle_desc,
                c_block_transfer=c_block_desc,
            )
            #manifest.append(new_operation)
            operations.append(new_operation)
    return operations
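# Illustrative usage (mirrors the driver script above, names unchanged):
#   ops = CreateGemmOperator()
#   for op in ops:  # one GemmOperation per (tile, block transfer, cshuffle, c block transfer) tuple
#       print(op.tile_desc.block_size, op.tile_desc.m_per_block, op.tile_desc.n_per_block)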
import os
import re
from hashlib import sha1
from typing import Any, Dict, OrderedDict
import jinja2
#from ...target import Target
#templating
FUNC_CALL_PARAM_TEMPLATE = jinja2.Template("(void *)({{name}})")
INSTANCE_TEMPLATE = jinja2.Template(
"""
using {{name}} = {{ config_name }};
"""
)
ARGS_PARSE_TEMPLATE = jinja2.Template(
"""
{% for idx in range(rank) %}
const int64_t in_{{idx}} = std::stoi(argv[{{ idx + 1 }}]);
{% endfor %}
"""
)
STRUCTS_DEF_TEMPLATE = jinja2.Template(
"""
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
// added (assumed fix, not in the original): the rocrand generator must be created
// before rocrand_set_seed / rocrand_generate_normal are called on it below
rocrand_create_generator(&generator, ROCRAND_RNG_PSEUDO_DEFAULT);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
ptrs.reserve(512);
}
~ProfilerMemoryPool() {
rocrand_destroy_generator(generator);
for(int i = 0; i < ptrs.size(); i++){
hipFree(ptrs[i]);
}
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
size_t length = size * sizeof(DType);
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
"""
)
FUNC_TEMPLATE = jinja2.Template(
"""
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/reduction_operator.hpp"
{{extra_headers}}
{{extra_code}}
{{instances_decl}}
{{func_signature}}
{
{{shape_eval}}
{{exec_paths}}
}
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{input}},
{{indent}} {{output}},
{% for name in input_dim_names %}
{{indent}} const_cast<int64_t *>(&{{name}}),
{% endfor %}
{{indent}} stream
{{indent}});
"""
)
PROFILER_TEMPLATE = jinja2.Template(
"""
size_t GLOBAL_WORKSPACE_SIZE = 0;
{{op_func}}
{{structs_def}}
int main(int argc, char** argv) {
{{args_parse}}
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
{{tensor_decl}}
// warmup
for(int i = 0; i < 3; ++i) {
{{func_call}}
}
// run
KernelTimerImpl timer;
timer.Start();
for(int i = 0; i < 5; ++i) {
{{func_call}}
}
timer.End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
}
"""
)
# rendering (messy; needs to be modularized and organized)
# def gen_profiler(
# shape_eval_template: jinja2.Template,
# exec_template: jinja2.Template,
# tensor_decl_template: jinja2.Template,
# extra_header_template: jinja2.Template,
# get_func_signature: Any,
# extra_code: str = "",
# func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE,
# indent: str = " ",
# ) -> str:
# shape_eval_template: jinja2.Template
# exec_template: jinja2.Template
# tensor_decl_template: jinja2.Template
#extra_header_template: jinja2.Template
# former gen_profiler parameters, kept at module level as placeholders so the
# rendering code below still runs outside the commented-out function:
get_func_signature: Any
extra_code: str = ""
func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE
indent: str = " "
# shape_eval = shape_eval_template.render(rank=2) #if shape_eval_template else ""
# exe_path = exec_template.render(instance="DeviceInstance",dtype="void",reduce_dims=1,rank=2,eps=eps,)
instances = INSTANCE_TEMPLATE.render(
    name="DeviceInstance",
    # placeholder: still points at the layernorm device op this profiler was presumably
    # adapted from, not yet at the generated gemm instances
    config_name="ck::tensor_operation::device::DeviceLayernormImpl",
)
op_func = FUNC_TEMPLATE.render(
instances_decl=instances,
#func_signature=get_func_signature(func_attrs),
#shape_eval=shape_eval,
#exec_paths=exe_path,
#extra_headers=extra_header_template.render(),
extra_code=extra_code,)
structs_def = STRUCTS_DEF_TEMPLATE.render()
args_parse = ARGS_PARSE_TEMPLATE.render(rank=2)
#tensor_decl = tensor_decl_template.render(rank=2)
input_dim_names = [f"in_{i}" for i in range(2)]
func_call = func_call_template.render(
func_name="norm",
input="(void *) memory_pool->RequestHalfTensorByIdx(0)",
gamma="(void *) memory_pool->RequestHalfTensorByIdx(2)",
beta="(void *) memory_pool->RequestHalfTensorByIdx(3)",
output="(void *) memory_pool->RequestHalfTensorByIdx(1)",
input_dim_names=input_dim_names,
indent=indent,
)
code = PROFILER_TEMPLATE.render(
op_func=op_func,
structs_def=structs_def,
args_parse=args_parse,
#tensor_decl=tensor_decl,
func_call=func_call,
)
# print(instances)
# print(args_parse)
# print(structs_def)
#print(func_call)
#print(op_func)
print(code)
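# Note (based on the templates above): `code` is a standalone HIP profiler source with a
# main() that parses the input sizes, allocates half-precision tensors from
# ProfilerMemoryPool, runs 3 warmup and 5 timed calls of the rendered function, and prints
# the workspace size and elapsed time.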
import jinja2
EXTRA_SHAPE_TEMPLATE = jinja2.Template(
"""
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
// M, N, G1, G2, G3 are expected to be defined by the surrounding generated code
ck::index_t M0 = M / G1 / G2;
ck::index_t M1 = G1;
ck::index_t M2 = G2;
ck::index_t N0 = G3;
ck::index_t N1 = N / G3;
// GEMM shape
//ck::index_t M = M0 * M1 * M2;
//ck::index_t N = N0 * N1;
//ck::index_t K = 128;
//ck::index_t stride_A = K;
//ck::index_t stride_B = K;
// E = [M0, N0, M1, N1, M2]
/* 0, 3, 1, 4, 2
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
*/
// E = [M2, M0, N0, M1, N1] 2, 0, 3, 1, 4
ck::index_t stride_E_M0 = N0* M1* N1;
ck::index_t stride_E_M1 = N1;
ck::index_t stride_E_M2 = M0* N0* M1* N1;
ck::index_t stride_E_N0 = M1 * N1;
ck::index_t stride_E_N1 = 1;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
"""
)
output = EXTRA_SHAPE_TEMPLATE.render(indent=" ")
print(output)
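# Worked example (illustrative numbers, not from the original): for the E = [M2, M0, N0, M1, N1]
# layout above, with M0=2, M1=4, M2=8, N0=3, N1=5:
#   stride_E_M0 = N0*M1*N1    = 3*4*5   = 60
#   stride_E_M1 = N1          = 5
#   stride_E_M2 = M0*N0*M1*N1 = 2*3*4*5 = 120
#   stride_E_N0 = M1*N1       = 4*5     = 20
#   stride_E_N1 = 1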
CC = {{cc}}
CFLAGS = {{CFLAGS}}
fPIC_flag = {{fPIC}}
obj_files = {{obj_files}}
%.obj : %.{{cpp}}
	{{cfile_cmd}}
%.obj : %.bin
	{{bfile_cmd}}

.PHONY: all clean clean_constants

all: {{target}}

{{target}}: $(obj_files)
	$(CC) -shared $(fPIC_flag) $(CFLAGS) -o $@ $(obj_files)

clean:
	rm -f *.obj {{target}} test.so

clean_constants:
	rm -f constants.bin
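# The {{...}} placeholders (cc, CFLAGS, fPIC, obj_files, cpp, cfile_cmd, bfile_cmd, target)
# are presumably filled in with jinja2 at code-generation time, like the C++ templates above,
# producing a makefile that compiles the generated sources and links them into {{target}}.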