Commit 5714d3c6 authored by Astha Rai

added preliminary version of instance generator for MIGraphX - generates string of instances for device_gemm_bilinear example
parent 475188ca
from dataclasses import dataclass


class DataType:
    f16 = "F16"
    f32 = "F32"
    f16_tuple = "F16_Tuple"


class Layout:
    ColumnMajor = "Col"
    RowMajor = "Row"
    Row_Tuple = "Row_Tuple"


class TensorOperation:
    PassThrough = "PassThrough"
    Bilinear = "Bilinear"


@dataclass
class TensorDesc:  # set up and import properly
    element: DataType
    layout: Layout
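

# Minimal usage sketch (illustrative only, not part of the generator itself):
# describe an f16, column-major tensor with the types above.
if __name__ == "__main__":
    a_desc = TensorDesc(DataType.f16, Layout.ColumnMajor)
    print(a_desc.element, a_desc.layout)  # F16 Col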
import enum
import ck_types
from copy import deepcopy
from dataclasses import dataclass
from enum import auto
from typing import List
import os.path
import shutil
import functools
import operator
import collections
import subprocess
import re
import gemm_op
from gemm_op import *
import user
from ck_types import *
from gemm_ex import *
#from make_template import *

# holds multiple gemm instances
op_collection = user.CreateGemmOperator()
for op in op_collection:
    x = EmitGemmInstance()
    x.emit(op)
import enum
import os.path
import shutil
import functools
import operator
import collections
import subprocess
import re
import gemm_op
from gemm_op import *
import user
def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text
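
# Minimal illustration of SubstituteTemplate (hypothetical keys, not part of the
# generator): every ${key} occurrence is replaced until the text stops changing.
#   >>> SubstituteTemplate("${type_a} * ${type_b}", {"type_a": "F16", "type_b": "F16"})
#   'F16 * F16'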
class EmitGemmInstance:
    def __init__(self):
        self.gemm_op_template = """
DeviceGemmMultipleD_Xdl_CShuffle<${layout_a}, ${layout_b}, ${layout_ds}, ${layout_e}, ${type_a}, ${type_b}, ${type_acc}, ${type_cshuffle}, ${type_ds}, ${type_e}, ${elementwise_op_a}, ${elementwise_op_b}, ${elementwise_op_cde}, ${Gemm_spec}, ${num_gemmk_prefetch_stage}, ${block_size}, ${mperblock}, ${nperblock}, ${kperblock}, ${ak1}, ${bk1}, ${mperXDL}, ${nperXDL}, ${mXdlperwave}, ${nXdlperwave}, ${ABT_thread_cluster_lengths_K0_M_K1}, ${ABT_thread_cluster_arrange_order}, ${ABT_src_access_order}, ${ABT_src_vec_dim}, ${ABT_src_scalar_per_vec}, ${ABT_dst_scalar_per_vec_k1}, ${ABT_lds_add_extra_m}, ${BBT_thread_cluster_lengths_K0_N_K1}, ${BBT_thread_cluster_arrange_order}, ${BBT_src_access_order}, ${BBT_src_vec_dim}, ${BBT_src_scalar_per_vec}, ${BBT_dst_scalar_per_vec_k1}, ${BBT_lds_add_extra_n}, ${CS_m_Xdl_per_wave_per_shuffle}, ${CS_n_Xdl_per_wave_per_shuffle}, ${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, ${CTT_scalar_per_vector_n_wave_n_per_Xdl}>,
"""
    def emit(self, operation):
        #name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block) + "_" + str(operation.tile_desc.ak1))
        values = {
            #'name' : name,
            'layout_a' : operation.A.layout,
            'layout_b' : operation.B.layout,
            'layout_ds' : operation.Ds.layout,
            'layout_e' : operation.E.layout,
            'type_a' : operation.A.element,
            'type_b' : operation.B.element,
            'type_acc' : operation.acc,
            'type_cshuffle' : operation.cs_type, #figure out how to arrange this
            'type_ds' : operation.Ds.element,
            'type_e' : operation.E.element,
            'elementwise_op_a' : operation.a_elem_op,
            'elementwise_op_b' : operation.b_elem_op,
            'elementwise_op_cde' : operation.cde_elem_op,
            'Gemm_spec' : operation.gemm_specialization,
            'num_gemmk_prefetch_stage' : str(operation.tile_desc.num_gemmk_prefetch_stage),
            'block_size' : str(operation.tile_desc.block_size),
            'mperblock' : str(operation.tile_desc.m_per_block),
            'nperblock' : str(operation.tile_desc.n_per_block),
            'kperblock' : str(operation.tile_desc.k_per_block),
            'ak1' : str(operation.tile_desc.ak1),
            'bk1' : str(operation.tile_desc.bk1),
            'mperXDL' : str(operation.tile_desc.m_per_XDL),
            'nperXDL' : str(operation.tile_desc.n_per_XDL),
            'mXdlperwave' : str(operation.tile_desc.m_Xdl_per_wave),
            'nXdlperwave' : str(operation.tile_desc.n_Xdl_per_wave),
            'ABT_thread_cluster_lengths_K0_M_K1' : operation.a_block_transfer.thread_cluster_length,
            'ABT_thread_cluster_arrange_order' : operation.a_block_transfer.thread_cluster_arrange_order,
            'ABT_src_access_order' : operation.a_block_transfer.src_access_order,
            'ABT_src_vec_dim' : str(operation.a_block_transfer.src_vec_dim),
            'ABT_src_scalar_per_vec' : str(operation.a_block_transfer.src_scalar_per_vector),
            'ABT_dst_scalar_per_vec_k1' : str(operation.a_block_transfer.dst_scalar_per_vector_k1),
            'ABT_lds_add_extra_m' : str(operation.a_block_transfer.lds_add_extra_dim),
            'BBT_thread_cluster_lengths_K0_N_K1' : operation.b_block_transfer.thread_cluster_length,
            'BBT_thread_cluster_arrange_order' : operation.b_block_transfer.thread_cluster_arrange_order,
            'BBT_src_access_order' : operation.b_block_transfer.src_access_order,
            'BBT_src_vec_dim' : str(operation.b_block_transfer.src_vec_dim),
            'BBT_src_scalar_per_vec' : str(operation.b_block_transfer.src_scalar_per_vector),
            'BBT_dst_scalar_per_vec_k1' : str(operation.b_block_transfer.dst_scalar_per_vector_k1),
            'BBT_lds_add_extra_n' : str(operation.b_block_transfer.lds_add_extra_dim),
            'CS_m_Xdl_per_wave_per_shuffle' : str(operation.cshuffle.m_Xdl_per_wave_per_shuffle),
            'CS_n_Xdl_per_wave_per_shuffle' : str(operation.cshuffle.n_Xdl_per_wave_per_shuffle),
            'CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl' : operation.c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl,
            'CTT_scalar_per_vector_n_wave_n_per_Xdl' : str(operation.c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl),
        }
        template = self.gemm_op_template
        name = (str(operation.tile_desc.block_size) + "_" + str(operation.tile_desc.m_per_block) + "_" + str(operation.tile_desc.n_per_block)
                + "_" + str(operation.tile_desc.k_per_block) + "_" + str(operation.tile_desc.ak1))
        # print(SubstituteTemplate(template, values))
        instances = SubstituteTemplate(template, values)
        print(instances)
        # cf = open("instances.cpp",'w')
        # cf.write(SubstituteTemplate(template, values))
        # cf.close()
        # cf = open("%s.cpp" % name,'w')
        # cf.write(SubstituteTemplate(template, values))
        # cf.close()
# take in input for gemm from user, send it to example template
# the structure for constructing this gemm op was taken from AIT's
# implementation of creating a gemm op
import enum
import ck_types
from copy import deepcopy
from dataclasses import dataclass
from enum import auto
from typing import List
from ck_types import *
class GemmType():
    GemmDefault = "ck::tensor_operation::device::GemmSpecialization::Default"
@dataclass
class TileDesc:
    block_size: int
    m_per_block: int
    n_per_block: int
    k_per_block: int
    ak1: int
    bk1: int
    m_per_XDL: int
    n_per_XDL: int
    m_Xdl_per_wave: int
    n_Xdl_per_wave: int
    num_gemmk_prefetch_stage: int

    def __str__(self) -> str:
        values = list(self.__dict__.values())
        return ", ".join(str(v) for v in values)
@dataclass
class BlockTransferDesc:
    thread_cluster_length: str
    thread_cluster_arrange_order: str
    src_access_order: str
    src_vec_dim: int
    src_scalar_per_vector: int
    dst_scalar_per_vector_k1: int
    lds_add_extra_dim: int

    def __str__(self) -> str:
        args = deepcopy(self.__dict__)
        return ", ".join(str(v) for v in args.values())


@dataclass
class CShuffleDesc:
    m_Xdl_per_wave_per_shuffle: int
    n_Xdl_per_wave_per_shuffle: int

    def __str__(self) -> str:
        args = deepcopy(self.__dict__)
        return ", ".join(str(v) for v in args.values())


@dataclass
class CBlockTransferDesc:
    cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl: str
    scalar_per_vector_n_wave_n_per_Xdl: int

    def __str__(self) -> str:
        args = deepcopy(self.__dict__)
        return ", ".join(str(v) for v in args.values())
@dataclass
class GemmOperation:
    A: TensorDesc
    B: TensorDesc
    acc: DataType
    cs_type: DataType
    Ds: TensorDesc
    E: TensorDesc
    a_elem_op: TensorOperation
    b_elem_op: TensorOperation
    cde_elem_op: TensorOperation
    gemm_specialization: GemmType  # GemmSpecialization
    tile_desc: TileDesc
    a_block_transfer: BlockTransferDesc
    b_block_transfer: BlockTransferDesc
    cshuffle: CShuffleDesc
    b1_block_transfer: BlockTransferDesc = None
    c_block_transfer: CBlockTransferDesc = None

    def __str__(self) -> str:
        io_name = "{gemm_specialization}_{a_dtype}{b_dtype}{c_dtype}_{a_layout}{b_layout}{c_layout}".format(
            #gemm_kind=library.GemmKindNames[self.operation_kind],
            gemm_specialization=self.gemm_specialization,
            a_dtype=self.A.element,
            b_dtype=self.B.element,
            c_dtype=self.E.element,
            a_layout=self.A.layout,
            b_layout=self.B.layout,
            c_layout=self.E.layout,
        )
        return io_name
DeviceGemmMultipleD_Xdl_CShuffle<
${layout_a},
${layout_b},
${layout_ds},
${layout_e},
${type_a},
${type_b},
${type_acc},
${type_cshuffle},
${type_ds},
${type_e},
${elementwise_op_a},
${elementwise_op_b},
${elementwise_op_cde},
${Gemm_spec},
${num_gemmk_prefetch_stage},
${block_size},
${mperblock},
${nperblock},
${kperblock},
${ak1},
${bk1},
${mperXDL},
${nperXDL},
${mXdlperwave},
${nXdlperwave},
${ABT_thread_cluster_lengths_K0_M_K1},
${ABT_thread_cluster_arrange_order},
${ABT_src_access_order},
${ABT_src_vec_dim},
${ABT_src_scalar_per_vec},
${ABT_dst_scalar_per_vec_k1},
${ABT_lds_add_extra_m},
${BBT_thread_cluster_lengths_K0_N_K1},
${BBT_thread_cluster_arrange_order},
${BBT_src_access_order},
${BBT_src_vec_dim},
${BBT_src_scalar_per_vec},
${BBT_dst_scalar_per_vec_k1},
${BBT_lds_add_extra_n},
${CS_m_Xdl_per_wave_per_shuffle},
${CS_n_Xdl_per_wave_per_shuffle},
${CTT_cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
${CTT_scalar_per_vector_n_wave_n_per_Xdl}>;
# the structure for creating a list of instances for an op
# was taken from Meta's AIT library
import gemm_op as gemm
import enum
from dataclasses import dataclass
from enum import auto
import ck_types
from ck_types import *
def CreateGemmOperator():
    #operation_kind = library.GemmKind.Gemm
    a_element_desc = TensorDesc(DataType.f16, Layout.ColumnMajor)
    b_element_desc = TensorDesc(DataType.f16, Layout.RowMajor)
    ds_element_desc = TensorDesc(DataType.f16_tuple, Layout.Row_Tuple)
    e_element_desc = TensorDesc(DataType.f16, Layout.RowMajor)
    a_element_op = TensorOperation.PassThrough
    b_element_op = TensorOperation.PassThrough
    cde_element_op = TensorOperation.Bilinear
    acc_type = DataType.f16
    cshuffle_type = DataType.f32
    tile_descriptions = [
        gemm.TileDesc(256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1),
        gemm.TileDesc(256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1),
        gemm.TileDesc(128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1),
        gemm.TileDesc(256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1),
        gemm.TileDesc(128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1),
        gemm.TileDesc(128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1),
        gemm.TileDesc(64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1),
        gemm.TileDesc(256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1),
        gemm.TileDesc(256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1),
        gemm.TileDesc(128, 128, 32, 32, 8, 8, 32, 32, 2, 1, 1),
        gemm.TileDesc(128, 32, 128, 32, 8, 8, 32, 32, 1, 2, 1),
        gemm.TileDesc(64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1),
        gemm.TileDesc(64, 32, 64, 32, 8, 8, 32, 32, 1, 2, 1),
    ]
    a_block_descriptions = [
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
    ]
    b_block_descriptions = [
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 64, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 32, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
        gemm.BlockTransferDesc("S<4, 16, 1>", "S<1, 0, 2>", "S<1, 0, 2>", 2, 8, 8, 1),
    ]
    # one CShuffleDesc per tile description above (13 entries; a shorter list
    # would make the zip() below silently drop the last configuration)
    cshuffle_descriptions = [
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
        gemm.CShuffleDesc(1, 1),
    ]
    c_block_descriptions = [
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 4>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 4>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 32, 1, 4>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 8>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 4>", 8),
        gemm.CBlockTransferDesc("S<1, 16, 1, 4>", 8),
    ]
    #a_block_descriptions = b_block_descriptions
    gemm_specialization = [
        gemm.GemmType.GemmDefault
    ]
    operations = []
    for gemm_spec in gemm_specialization:
        for tile_desc, a_block_desc, b_block_desc, cshuffle_desc, c_block_desc in zip(
            tile_descriptions,
            a_block_descriptions,
            b_block_descriptions,
            cshuffle_descriptions,
            c_block_descriptions,
        ):
            new_operation = gemm.GemmOperation(
                #operation_kind=operation_kind,
                A=a_element_desc,
                B=b_element_desc,
                acc=acc_type,
                cs_type=cshuffle_type,
                Ds=ds_element_desc,
                E=e_element_desc,
                a_elem_op=a_element_op,
                b_elem_op=b_element_op,
                cde_elem_op=cde_element_op,
                gemm_specialization=gemm_spec,
                tile_desc=tile_desc,
                a_block_transfer=a_block_desc,
                b_block_transfer=b_block_desc,
                cshuffle=cshuffle_desc,
                c_block_transfer=c_block_desc,
            )
            #manifest.append(new_operation)
            operations.append(new_operation)
    return operations
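

# Hypothetical smoke test (not part of the original commit): with one
# CShuffleDesc per TileDesc above, CreateGemmOperator yields one GemmOperation
# per tile configuration; a shorter list would make zip() truncate silently.
if __name__ == "__main__":
    ops = CreateGemmOperator()
    print(len(ops))          # 13
    print(ops[0].tile_desc)  # 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1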
import os
import re
from hashlib import sha1
from typing import Any, Dict, OrderedDict
import jinja2
#from ...target import Target
#templating
FUNC_CALL_PARAM_TEMPLATE = jinja2.Template("(void *)({{name}})")
INSTANCE_TEMPLATE = jinja2.Template(
"""
using {{name}} = {{ config_name }};
"""
)
ARGS_PARSE_TEMPLATE = jinja2.Template(
"""
{% for idx in range(rank) %}
const int64_t in_{{idx}} = std::stoi(argv[{{ idx + 1 }}]);
{% endfor %}
"""
)
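
# For reference, ARGS_PARSE_TEMPLATE.render(rank=2) (as used further below)
# expands, modulo whitespace, to:
#   const int64_t in_0 = std::stoi(argv[1]);
#   const int64_t in_1 = std::stoi(argv[2]);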
STRUCTS_DEF_TEMPLATE = jinja2.Template(
    """
struct ProfilerMemoryPool {
  ProfilerMemoryPool() {
    std::random_device rd;
    gen = std::mt19937(rd());
    uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
    // create the rocRAND generator used by AllocateGaussianTensor
    rocrand_create_generator(&generator, ROCRAND_RNG_PSEUDO_DEFAULT);
    offsets.reserve(512);
    strides.reserve(512);
    copies.reserve(512);
    ptrs.reserve(512);
  }

  ~ProfilerMemoryPool() {
    rocrand_destroy_generator(generator);
    for(int i = 0; i < ptrs.size(); i++){
      hipFree(ptrs[i]);
    }
  }

  template <typename DType>
  DType* AllocateGaussianTensor(int64_t size) {
    size_t length = size * sizeof(DType);
    DType* d_x;
    hipMalloc(&d_x, length);
    float mean = 0.0f;
    float stddev = 1.0f;
    uint64_t seed = uniform_dist(gen);
    rocrand_set_seed(generator, seed);
    rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
    return d_x;
  }

  ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
    return reinterpret_cast<ck::half_t*>(
        AllocateGaussianTensor<ck::half_t>(size));
  }

  int AllocateHalfTensor(int64_t size, int64_t copy) {
    offsets.push_back(0);
    strides.push_back(size);
    copies.push_back(copy);
    auto ptr = AllocateHalfGaussianTensor(size * copy);
    ptrs.push_back(reinterpret_cast<void*>(ptr));
    return ptrs.size() - 1;
  }

  ck::half_t* RequestHalfTensorByIdx(int idx) {
    auto copy = copies.at(idx);
    auto offset = offsets.at(idx);
    auto stride = strides.at(idx);
    ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
    ptr += offset;
    offset += stride;
    if (offset == copy * stride) {
      offset = 0;
    }
    offsets[idx] = offset;
    return ptr;
  }

  std::vector<int64_t> offsets;
  std::vector<int64_t> strides;
  std::vector<int64_t> copies;
  std::vector<void*> ptrs;
  std::mt19937 gen;
  std::uniform_int_distribution<int64_t> uniform_dist;
  rocrand_generator generator;
};

// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
    hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}

void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }

void DeviceMem::ToDevice(const void* p) const
{
    hipGetErrorString(
        hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}

void DeviceMem::FromDevice(void* p) const
{
    hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}

DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }

struct KernelTimerImpl
{
    KernelTimerImpl()
    {
        hipGetErrorString(hipEventCreate(&mStart));
        hipGetErrorString(hipEventCreate(&mEnd));
    }
    ~KernelTimerImpl()
    {
        hipGetErrorString(hipEventDestroy(mStart));
        hipGetErrorString(hipEventDestroy(mEnd));
    }
    void Start()
    {
        hipGetErrorString(hipDeviceSynchronize());
        hipGetErrorString(hipEventRecord(mStart, nullptr));
    }
    void End()
    {
        hipGetErrorString(hipEventRecord(mEnd, nullptr));
        hipGetErrorString(hipEventSynchronize(mEnd));
    }
    float GetElapsedTime() const
    {
        float time;
        hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
        return time;
    }

    hipEvent_t mStart, mEnd;
};
// >>> hack end
"""
)
FUNC_TEMPLATE = jinja2.Template(
"""
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/reduction_operator.hpp"
{{extra_headers}}
{{extra_code}}
{{instances_decl}}
{{func_signature}}
{
{{shape_eval}}
{{exec_paths}}
}
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{input}},
{{indent}} {{output}},
{% for name in input_dim_names %}
{{indent}} const_cast<int64_t *>(&{{name}}),
{% endfor %}
{{indent}} stream
{{indent}});
"""
)
PROFILER_TEMPLATE = jinja2.Template(
"""
size_t GLOBAL_WORKSPACE_SIZE = 0;
{{op_func}}
{{structs_def}}
int main(int argc, char** argv) {
{{args_parse}}
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
{{tensor_decl}}
// warmup
for(int i = 0; i < 3; ++i) {
{{func_call}}
}
// run
KernelTimerImpl timer;
timer.Start();
for(int i = 0; i < 5; ++i) {
{{func_call}}
}
timer.End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
}
"""
)
# rendering (messy, need to modularize and organize)
# def gen_profiler(
#     shape_eval_template: jinja2.Template,
#     exec_template: jinja2.Template,
#     tensor_decl_template: jinja2.Template,
#     extra_header_template: jinja2.Template,
#     get_func_signature: Any,
#     extra_code: str = "",
#     func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE,
#     indent: str = " ",
# ) -> str:

# module-level stand-ins while the gen_profiler signature above stays commented out
get_func_signature: Any
extra_code: str = ""
func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE
indent: str = " "

# shape_eval = shape_eval_template.render(rank=2)  # if shape_eval_template else ""
# exe_path = exec_template.render(instance="DeviceInstance", dtype="void", reduce_dims=1, rank=2, eps=eps,)
instances = INSTANCE_TEMPLATE.render(
name="DeviceInstance", config_name= "ck::tensor_operation::device::DeviceLayernormImpl",)
op_func = FUNC_TEMPLATE.render(
instances_decl=instances,
#func_signature=get_func_signature(func_attrs),
#shape_eval=shape_eval,
#exec_paths=exe_path,
#extra_headers=extra_header_template.render(),
extra_code=extra_code,)
structs_def = STRUCTS_DEF_TEMPLATE.render()
args_parse = ARGS_PARSE_TEMPLATE.render(rank=2)
#tensor_decl = tensor_decl_template.render(rank=2)
input_dim_names = [f"in_{i}" for i in range(2)]
func_call = func_call_template.render(
func_name="norm",
input="(void *) memory_pool->RequestHalfTensorByIdx(0)",
gamma="(void *) memory_pool->RequestHalfTensorByIdx(2)",
beta="(void *) memory_pool->RequestHalfTensorByIdx(3)",
output="(void *) memory_pool->RequestHalfTensorByIdx(1)",
input_dim_names=input_dim_names,
indent=indent,
)
code = PROFILER_TEMPLATE.render(
op_func=op_func,
structs_def=structs_def,
args_parse=args_parse,
#tensor_decl=tensor_decl,
func_call=func_call,
)
# print(instances)
# print(args_parse)
# print(structs_def)
#print(func_call)
#print(op_func)
print(code)
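
# Hypothetical follow-up (not part of this commit): persist the rendered profiler
# source instead of only printing it; the file name here is arbitrary.
# with open("profiler_main.cpp", "w") as f:
#     f.write(code)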
import jinja2
EXTRA_SHAPE_TEMPLATE = jinja2.Template(
"""
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
ck::index_t M0 = M / G1 / G2;
ck::index_t M1 = G1;
ck::index_t M2 = G2;
ck::index_t N0 = G3;
ck::index_t N1 = N / G3;
// GEMM shape
//ck::index_t M = M0 * M1 * M2;
//ck::index_t N = N0 * N1;
//ck::index_t K = 128;
//ck::index_t stride_A = K;
//ck::index_t stride_B = K;
// E = [M0, N0, M1, N1, M2]
/* 0, 3, 1, 4, 2
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
*/
// E = [M2, M0, N0, M1, N1] 2, 0, 3, 1, 4
ck::index_t stride_E_M0 = N0* M1* N1;
ck::index_t stride_E_M1 = N1;
ck::index_t stride_E_M2 = M0* N0* M1* N1;
ck::index_t stride_E_N0 = M1 * N1;
ck::index_t stride_E_N1 = 1;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
"""
)
output = EXTRA_SHAPE_TEMPLATE.render(indent=" ")
print(output)
CC = {{cc}}
CFLAGS = {{CFLAGS}}
fPIC_flag = {{fPIC}}
obj_files = {{obj_files}}

%.obj : %.{{cpp}}
	{{cfile_cmd}}

%.obj : %.bin
	{{bfile_cmd}}

.PHONY: all clean clean_constants

all: {{target}}

{{target}}: $(obj_files)
	$(CC) -shared $(fPIC_flag) $(CFLAGS) -o $@ $(obj_files)

clean:
	rm -f *.obj {{target}} test.so

clean_constants:
	rm -f constants.bin
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// NOTE: the DeviceGemmDl template parameters have already been substituted with
// concrete values for this instance (ck::half_t A/B/C data, float accumulator,
// ColMajor x RowMajor -> RowMajor layouts, PassThrough elementwise ops,
// GemmSpecialization::Default, BlockSize 256, 128x128 tiles, K0PerBlock 16,
// K1 2, per-thread sizes 4/4/1, and the S<...> thread-cluster / access-order
// sequences repeated in the GridwiseGemm alias below), so the struct is emitted
// here as a non-template type.
struct DeviceGemmDl : public DeviceGemm<ck::tensor_layout::gemm::ColMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<2>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % 2 == 0);
const index_t K0 = K / 2;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColMajor>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColMajor>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(ck::tensor_operation::device::GemmSpecialization::Default == GemmSpecialization::MNPadding)
{
const auto PadM = (128 - M % 128) % 128;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % 2 == 0);
const index_t K0 = K / 2;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(ck::tensor_operation::device::GemmSpecialization::Default == GemmSpecialization::MNPadding)
{
const auto PadN = (128 - N % 128) % 128;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(ck::tensor_operation::device::GemmSpecialization::Default == GemmSpecialization::MNPadding)
{
const auto PadM = (128 - M % 128) % 128;
const auto PadN = (128 - N % 128) % 128;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<256, // BlockSize
ck::half_t,
float,
ck::half_t,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
128,
128,
16,
2,
4,
4,
1,
S<8, 2>,
S<8, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const ck::half_t* p_a_grid,
const ck::half_t* p_b_grid,
ck::half_t* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
ck::tensor_operation::element_wise::PassThrough a_element_op,
ck::tensor_operation::element_wise::PassThrough b_element_op,
ck::tensor_operation::element_wise::PassThrough c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
// private:
const ck::half_t* p_a_grid_;
const ck::half_t* p_b_grid_;
ck::half_t* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
DefaultBlock2CTileMap block_2_ctile_map_;
// TODO: unused, but may be useful in future.
index_t M01_;
index_t N01_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
ck::tensor_operation::element_wise::PassThrough a_element_op_;
ck::tensor_operation::element_wise::PassThrough b_element_op_;
ck::tensor_operation::element_wise::PassThrough c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmDl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ck::half_t,
ck::half_t,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(256),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ck::half_t,
ck::half_t,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(256),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ck::half_t,
ck::half_t,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(256),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ck::half_t,
ck::half_t,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(256),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
}
else
{
return false;
}
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ck::half_t* p_a,
const ck::half_t* p_b,
ck::half_t* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
ck::tensor_operation::element_wise::PassThrough a_element_op,
ck::tensor_operation::element_wise::PassThrough b_element_op,
ck::tensor_operation::element_wise::PassThrough c_element_op)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
ck::tensor_operation::element_wise::PassThrough a_element_op,
ck::tensor_operation::element_wise::PassThrough b_element_op,
ck::tensor_operation::element_wise::PassThrough c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ck::half_t*>(p_a),
static_cast<const ck::half_t*>(p_b),
static_cast<ck::half_t*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDl"
<< "<"
<< 256 << ", "
<< 128 << ", "
<< 128 << ", "
<< 16 << ", "
<< 2 << ", "
<< 4 << ", "
<< 4 << ", "
<< 1
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck