Commit 516bbdcb authored by Astha Rai

initial push with templated dev op

parent fde6d274
gemm: xx.o
CFLAGS=-I ~/rocm/composable_kernel/include -I /opt/rocm-5.1.1/hip/include -I ~/rocm/composable_kernel/include/ -I ~/rocm/composable_kernel/include/ck/ -I ~/rocm/composable_kernel/include/ck/problem_transform/ -I ~/rocm/composable_kernel/include/ck/tensor/ -I ~/rocm/composable_kernel/include/ck/tensor_description/ -I ~/rocm/composable_kernel/include/ck/tensor_operation/ -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/block/ -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/device/ -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/device/impl/ -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/element/ -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/grid/ -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/thread/ -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/warp/ -I ~/rocm/composable_kernel/include/ck/host_utility -I /external/include/half/ -I ~/rocm/composable_kernel/library/include/ck/library/host/ -I ~/rocm/composable_kernel/library/include/ck/library/host_tensor/ -I ~/rocm/composable_kernel/library/include/ck/library/obselete_driver_offline/ -I ~/rocm/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/ -I ~/rocm/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/ -I ~/rocm/composable_kernel/library/include/ck/library/tensor_operation_instance/ -I ~/rocm/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/ -I ~/rocm/composable_kernel/library/include/ck/library/tensor_op/ -I ~/rocm/composable_kernel/library/include/ck/library/utility/ -I ~/rocm/composable_kernel/profiler/include/
CXXFLAGS = -std=c++17
xx.o:
	hipcc -c $(CFLAGS) $(CXXFLAGS) xx.cpp
CC = {{cc}}
CFLAGS = {{CFLAGS}}
fPIC_flag = {{fPIC}}
obj_files = {{obj_files}}
%.obj : %.{{cpp}}
	{{cfile_cmd}}
%.obj : %.bin
	{{bfile_cmd}}
.PHONY: all clean clean_constants
all: {{target}}
{{target}}: $(obj_files)
	$(CC) -shared $(fPIC_flag) $(CFLAGS) -o $@ $(obj_files)
clean:
	rm -f *.obj {{target}} test.so
clean_constants:
	rm -f constants.bin
import jinja2
SHAPE_EVAL_TEMPLATE = jinja2.Template(
"""
int M = *in_{{ range(rank - 1)|join(' * *in_') }};
int N = *in_{{rank - 1}};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<K1>{};
"""
)
output = SHAPE_EVAL_TEMPLATE.render(rank=2)
print(output)
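# For rank=2 the template above renders (ignoring surrounding whitespace) to:
#   int M = *in_0;
#   int N = *in_1;
#   static constexpr auto I0 = Number<0>{}; ... static constexpr auto I5 = Number<5>{};
#   static constexpr auto K1Number = Number<K1>{};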
import enum
import os.path
import shutil
import functools
import operator
import collections
import re
def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text
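# Minimal usage sketch (illustrative values, not part of the original commit):
#   SubstituteTemplate("hello ${name}", {"name": "world"})  ->  "hello world"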
class EmitGemmInstance:
    def __init__(self):
        self.gemm_kernel_template = """
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_v1r3(const ${type_ab}* __restrict__ ${p_a_grid},
const ${type_ab}* __restrict__ ${p_b_grid},
${type_c}* __restrict__ ${p_c_grid},
const ${A_GridDesc_K0_M_K1} ${a_grid_desc_k0_m0_m1_k1},
const ${BGridDesc_K0_N_K1} ${b_grid_desc_k0_n0_n1_k1},
const ${CGridDesc_M0_M10_M11_N0_N10_N11} ${c_grid_desc_m0_m10_m11_n0_n10_n11},
const Block2CTileMap ${block_2_ctile_map})
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(${type_ab});
__shared__ ${type_ab} p_shared_block[shared_block_size];
GridwiseGemm::Run(${p_a_grid},
${p_b_grid},
${p_c_grid},
p_shared_block,
${a_grid_desc_k0_m0_m1_k1},
${b_grid_desc_k0_n0_n1_k1},
${c_grid_desc_m0_m10_m11_n0_n10_n11},
${block_2_ctile_map},
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
${type_ab},
${type_acc},
${type_c},
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
${A_GridDesc_K0_M_K1},
${BGridDesc_K0_N_K1},
${CGridDesc_M_N},
${mperblock},
${nperblock},
${k0perblock},
${k1value},
${M1PerThreadM111},
${N1PerThreadN111},
${KPerThread},
${M11N11ThreadClusterM110Xs},
${M11N11ThreadClusterN110Xs},
${ABlockTransferThreadSliceLengths_K0_M0_M1_K1},
${ABlockTransferThreadClusterLengths_K0_M0_M1_K1},
${ABlockTransferThreadClusterArrangeOrder},
${ABlockTransferSrcAccessOrder},
${ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1},
${ABlockTransferSrcVectorTensorContiguousDimOrder},
${ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1},
${BBlockTransferThreadSliceLengths_K0_N0_N1_K1},
${BBlockTransferThreadClusterLengths_K0_N0_N1_K1},
${BBlockTransferThreadClusterArrangeOrder},
${BBlockTransferSrcAccessOrder},
${BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1},
${BBlockTransferSrcVectorTensorContiguousDimOrder},
${BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1},
${CThreadTransferSrcDstAccessOrder},
${CThreadTransferSrcDstVectorDim},
${CThreadTransferDstScalarPerVector}>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
"""
    def emit(self):
        values = {
            'function_name': "gemm",
            'type_a': 'ck::half_t',
            'type_b': 'ck::half_t',
            'type_c': 'ck::half_t',
            'type_acc': 'float',
            'layout_a': 'ck::tensor_layout::gemm::RowMajor',
            'layout_b': 'ck::tensor_layout::gemm::RowMajor',
            'layout_c': 'ck::tensor_layout::gemm::RowMajor',
            'elementwise_op_a': 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_b': 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_c': 'ck::tensor_operation::element_wise::PassThrough',
            'Gemm_spec': 'ck::tensor_operation::device::GemmSpecialization::MNKPadding',
            'block_size': '256',
            'mperblock': '64',
            'nperblock': '128',
            'kperblock': '32',
            'k1': '8',
            'mperxdl': '32',
            'nperxdl': '32',
            'mxdlperwave': '1',
            'nxdlperwave': '2',
            'threadclusterlength_a': 'ck::Sequence<4,64,1>',
            'threadclusterarrange_a': 'ck::Sequence<1,0,2>',
            'srcaccessorder_a': 'ck::Sequence<1,0,2>',
            'srcvectordim_a': '2',
            'srcscalarpervec_a': '8',
            'dstscalarpervec_a': '8',
            'add_extra_dim_a': '1',
            'threadclusterlength_b': 'ck::Sequence<8,32,1>',
            'threadclusterarrange_b': 'ck::Sequence<0,2,1>',
            'srcaccessorder_b': 'ck::Sequence<0,2,1>',
            'srcvectordim_b': '1',
            'srcscalarpervec_b': '4',
            'dstscalarpervec_b': '2',
            'add_extra_dim_b': '0',
            'dstscalarpervec_c': '8'
        }
        template = self.gemm_kernel_template
        print(SubstituteTemplate(template, values))
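# Minimal usage sketch (hypothetical entry point, not part of the original commit);
# only ${...} keys present in the values dict above are substituted, the rest of the
# kernel template is printed unchanged.
if __name__ == "__main__":
    EmitGemmInstance().emit()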
import os
import re
from hashlib import sha1
from typing import Any, Dict, OrderedDict
import jinja2
#from ...target import Target
#templating
FUNC_CALL_PARAM_TEMPLATE = jinja2.Template("(void *)({{name}})")
INSTANCE_TEMPLATE = jinja2.Template(
"""
using {{name}} = {{ config_name }};
"""
)
ARGS_PARSE_TEMPLATE = jinja2.Template(
"""
{% for idx in range(rank) %}
const int64_t in_{{idx}} = std::stoi(argv[{{ idx + 1 }}]);
{% endfor %}
"""
)
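# For rank=2 this template renders (modulo whitespace) to:
#   const int64_t in_0 = std::stoi(argv[1]);
#   const int64_t in_1 = std::stoi(argv[2]);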
STRUCTS_DEF_TEMPLATE = jinja2.Template(
"""
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
// create the rocRAND generator used by AllocateGaussianTensor (it was never initialized)
rocrand_create_generator(&generator, ROCRAND_RNG_PSEUDO_DEFAULT);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
ptrs.reserve(512);
}
~ProfilerMemoryPool() {
rocrand_destroy_generator(generator);
for(size_t i = 0; i < ptrs.size(); i++){
hipFree(ptrs[i]);
}
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
size_t length = size * sizeof(DType);
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
"""
)
FUNC_TEMPLATE = jinja2.Template(
"""
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/reduction_operator.hpp"
{{extra_headers}}
{{extra_code}}
{{instances_decl}}
{{func_signature}}
{
{{shape_eval}}
{{exec_paths}}
}
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{input}},
{{indent}} {{output}},
{% for name in input_dim_names %}
{{indent}} const_cast<int64_t *>(&{{name}}),
{% endfor %}
{{indent}} stream
{{indent}});
"""
)
PROFILER_TEMPLATE = jinja2.Template(
"""
size_t GLOBAL_WORKSPACE_SIZE = 0;
{{op_func}}
{{structs_def}}
int main(int argc, char** argv) {
{{args_parse}}
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
{{tensor_decl}}
// warmup
for(int i = 0; i < 3; ++i) {
{{func_call}}
}
// run
KernelTimerImpl timer;
timer.Start();
for(int i = 0; i < 5; ++i) {
{{func_call}}
}
timer.End();
std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl;
std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
}
"""
)
# rendering (messy; needs to be modularized and organized)
# def gen_profiler(
# shape_eval_template: jinja2.Template,
# exec_template: jinja2.Template,
# tensor_decl_template: jinja2.Template,
# extra_header_template: jinja2.Template,
# get_func_signature: Any,
# extra_code: str = "",
# func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE,
# indent: str = " ",
# ) -> str:
# shape_eval_template: jinja2.Template
# exec_template: jinja2.Template
# tensor_decl_template: jinja2.Template
#extra_header_template: jinja2.Template
# module-level stand-ins for the commented-out gen_profiler parameters above
get_func_signature: Any = None
extra_code: str = ""
func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE
indent: str = "  "
# shape_eval = shape_eval_template.render(rank=2) #if shape_eval_template else ""
# exe_path = exec_template.render(instance="DeviceInstance",dtype="void",reduce_dims=1,rank=2,eps=eps,)
instances = INSTANCE_TEMPLATE.render(
    name="DeviceInstance",
    config_name="ck::tensor_operation::device::DeviceLayernormImpl",
)
op_func = FUNC_TEMPLATE.render(
    instances_decl=instances,
    # func_signature=get_func_signature(func_attrs),
    # shape_eval=shape_eval,
    # exec_paths=exe_path,
    # extra_headers=extra_header_template.render(),
    extra_code=extra_code,
)
structs_def = STRUCTS_DEF_TEMPLATE.render()
args_parse = ARGS_PARSE_TEMPLATE.render(rank=2)
#tensor_decl = tensor_decl_template.render(rank=2)
input_dim_names = [f"in_{i}" for i in range(2)]
func_call = func_call_template.render(
func_name="norm",
input="(void *) memory_pool->RequestHalfTensorByIdx(0)",
gamma="(void *) memory_pool->RequestHalfTensorByIdx(2)",
beta="(void *) memory_pool->RequestHalfTensorByIdx(3)",
output="(void *) memory_pool->RequestHalfTensorByIdx(1)",
input_dim_names=input_dim_names,
indent=indent,
)
code = PROFILER_TEMPLATE.render(
op_func=op_func,
structs_def=structs_def,
args_parse=args_parse,
#tensor_decl=tensor_decl,
func_call=func_call,
)
# print(instances)
# print(args_parse)
# print(structs_def)
#print(func_call)
#print(op_func)
print(code)
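# Hypothetical workflow sketch (not part of the commit): the emitted source is meant
# to be saved and compiled with hipcc (see the Makefile above), then run with the
# two tensor dims as argv, e.g.:
#   python gen_profiler.py > profiler.cpp
#   hipcc -std=c++17 $(CFLAGS) profiler.cpp -o profiler
#   ./profiler 1024 768   # prints "WS:<workspace>" and "TIME:<ms>"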
import jinja2
EXTRA_SHAPE_TEMPLATE = jinja2.Template(
"""
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
ck::index_t M0 = M / G1 / G2;
ck::index_t M1 = G1;
ck::index_t M2 = G2;
ck::index_t N0 = G3;
ck::index_t N1 = N / G3;
// GEMM shape
//ck::index_t M = M0 * M1 * M2;
//ck::index_t N = N0 * N1;
//ck::index_t K = 128;
//ck::index_t stride_A = K;
//ck::index_t stride_B = K;
// E = [M0, N0, M1, N1, M2]
/* 0, 3, 1, 4, 2
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
*/
// E = [M2, M0, N0, M1, N1] 2, 0, 3, 1, 4
ck::index_t stride_E_M0 = N0* M1* N1;
ck::index_t stride_E_M1 = N1;
ck::index_t stride_E_M2 = M0* N0* M1* N1;
ck::index_t stride_E_N0 = M1 * N1;
ck::index_t stride_E_N1 = 1;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
"""
)
output = EXTRA_SHAPE_TEMPLATE.render(indent=" ")
print(output)
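# Note: the rendered snippet assumes the enclosing generated function already
# declares M, N, G1, G2, G3 and the a_dim1/b_dim1/c_dim1 pointers it reads from.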
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "logging.h"
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1,
256, // block_size
64, // m_per_block
128, // n_per_block
32, // k_per_block
8, // ak1
2, // bk1
32, // m_per_xdl
32, // n_per_xdl
1, // m_xdl_per_wave
2, // n_xdl_per_wave
ck::Sequence<4,64,1>, // thread_cluster_length
ck::Sequence<1,0,2>, // thread_cluster_arrange_order
ck::Sequence<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<8,32,1>, // thread_cluster_length
ck::Sequence<0,2,1>, // thread_cluster_arrange_order
ck::Sequence<0,2,1>, // src_access_order
1, // src_vector_dim
4, // src_scalar_per_vector
2, // dst_scalar_per_vector
0, // add_extra_dim
1, // m_xdl_per_wave
1, // n_xdl_per_wave
ck::Sequence<1,32,1,8>, // m_n_block_wave_per_xdl
8 // scalar_per_vector
>;
using fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0 = gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT;
void gemm_rrr_3(
void * in_ptr,
void * weight_ptr,
void * out_ptr,
int64_t* a_dim0,
int64_t* a_dim1,
int64_t* b_dim0,
int64_t* b_dim1,
int64_t* c_dim0,
int64_t* c_dim1,
hipStream_t stream
) {
ck::index_t M = (*a_dim0);
ck::index_t N = (*b_dim1);
ck::index_t K = (*a_dim1);
int64_t offset_a = 0;
int64_t offset_b = 0;
int64_t offset_c = 0;
ck::index_t stride_a = *a_dim1;
ck::index_t stride_b = *b_dim1;
ck::index_t stride_c = *c_dim1;
if (M == 256 && N == 32 && K == 128) {
auto op = fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(
static_cast<ck::half_t *>(in_ptr) + offset_a,
static_cast<ck::half_t *>(weight_ptr) + offset_b,
static_cast<ck::half_t *>(out_ptr) + offset_c,
M,
N,
K,
stride_a,
stride_b,
stride_c,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}
);
if(!op.IsSupportedArgument(argument)) {
LOG(FATAL) << "wrong! " << op.GetTypeString() << " with the specified compilation parameters does not support this Gemm problem.";
}
invoker.Run(argument, StreamConfig{stream, false});
return;
}
LOG(FATAL) << "Unsupported workload for this gemm specialization.";
}
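// Minimal call sketch (hypothetical, not part of the commit). Only the single
// specialization above is handled, so the dims must be M=256, N=32, K=128 with
// row-major strides, otherwise the function logs FATAL:
//   int64_t a[2] = {256, 128}, b[2] = {128, 32}, c[2] = {256, 32};
//   gemm_rrr_3(dev_a, dev_b, dev_c,
//              &a[0], &a[1], &b[0], &b[1], &c[0], &c[1],
//              /*stream=*/nullptr);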