Commit 47cc9b7e authored by Astha Rai

added compilation of a shared library and multiple instances for GEMM; cleaned up code design

parent adbefd90
# CK Python API
This API uses Python to generate instances of the operations present in CK and compile them into either a shared library or an executable that runs the instances.
There are two directories: `normal` and `shared`. The `normal` directory generates a single instance and compiles it into an executable to run, while the `shared` directory generates multiple instances and compiles them into a shared library.
## Normal
Generates one GEMM instance (`ex.cpp`) plus a Makefile, then builds the `gemm` executable with `hipcc`.
## Shared
Generates instance sources (e.g. `xx.cpp`) that are compiled with `-fPIC` and linked into a shared library through a templated Makefile.
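A minimal sketch of driving a generator from Python is shown below. The module name `gemm_instance` is hypothetical; the scripts in this commit simply construct `EmitGemmInstance` at module level and call `emit()`.

```python
# Hypothetical usage sketch; the module/file name is illustrative only.
from gemm_instance import EmitGemmInstance

emitter = EmitGemmInstance()
emitter.emit()  # renders the instance source and Makefile, then runs `make` (normal flow)
```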
gemm: xx.o
CFLAGS=-I ~/workspace/composable_kernel/include -I /opt/workspace/rocm-5.1.1/hip/include -I ~/workspace/composable_kernel/include/ -I ~/workspace/composable_kernel/include/ck/ -I ~/workspace/composable_kernel/include/ck/problem_transform/ -I ~/workspace/composable_kernel/include/ck/tensor/ -I ~/workspace/composable_kernel/include/ck/tensor_description/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/ -I ~/workspace/composable_kernel/include/ck/host_utility -I /external/include/half/ -I ~/workspace/composable_kernel/library/include/ck/library/host/ -I ~/workspace/composable_kernel/library/include/ck/library/host_tensor/ -I ~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_op/ -I ~/workspace/composable_kernel/library/include/ck/library/utility/ -I ~/workspace/composable_kernel/profiler/include/
CXXFLAGS = -std=c++17
xx.o:
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w /opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c xx.cpp
CC = {{cc}}
CFLAGS = {{CFLAGS}}
fPIC_flag = {{fPIC}}
obj_files = {{obj_files}}
%.obj : %.{{cpp}}
{{cfile_cmd}}
%.obj : %.bin
{{bfile_cmd}}
.PHONY: all clean clean_constants
all: {{target}}
{{target}}: $(obj_files)
$(CC) -shared $(fPIC_flag) $(CFLAGS) -o $@ $(obj_files)
clean:
rm -f *.obj {{target}} test.so
clean_constants:
rm -f constants.bin
import enum
import os.path
import shutil
import functools
import operator
import collections
import subprocess
import re
def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text
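# Illustration only (not emitted anywhere): SubstituteTemplate replaces each ${key}
# with values[key] and loops until the text stops changing, so
#   SubstituteTemplate("using T = ${type_a};", {'type_a': 'ck::half_t'})
# returns "using T = ck::half_t;".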
class EmitGemmInstance:
    def __init__(self):
        self.make_template = """
CFLAGS=-I ~/workspace/composable_kernel/include -I /opt/workspace/rocm-5.1.1/hip/include -I ~/workspace/composable_kernel/include/ -I ~/workspace/composable_kernel/include/ck/ -I ~/workspace/composable_kernel/example/01_gemm/ -I ~/workspace/composable_kernel/library/include/ -I ~/workspace/composable_kernel/library/src/utility/ -I ~/workspace/composable_kernel/include/ck/problem_transform/ -I ~/workspace/composable_kernel/include/ck/tensor/ -I ~/workspace/composable_kernel/include/ck/tensor_description/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/ -I ~/workspace/composable_kernel/include/ck/host_utility -I /external/include/half/ -I ~/workspace/composable_kernel/library/include/ck/library/host/ -I ~/workspace/composable_kernel/library/include/ck/library/host_tensor/ -I ~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_op/ -I ~/workspace/composable_kernel/library/include/ck/library/utility/ -I ~/workspace/composable_kernel/profiler/include/
CXXFLAGS = -std=c++17
gemm: ex.o host_tensor.o device_memory.o
hipcc $(CXXFLAGS) $(CFLAGS) ex.o host_tensor.o device_memory.o -o gemm
device_memory.o: ../../../../library/src/utility/device_memory.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/device_memory.cpp
host_tensor.o: ../../../../library/src/utility/host_tensor.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/host_tensor.cpp
ex.o:
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w /opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c ex.cpp
"""
        self.gemm_devop_template = """
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
${type_a},
${type_b},
${type_c},
${type_acc},
${layout_a},
${layout_b},
${layout_c},
${elementwise_op_a},
${elementwise_op_b},
${elementwise_op_c},
${Gemm_spec},
${block_size},
${mperblock},
${nperblock},
${k0perblock},
${k1},
${m1perthread},
${n1perthread},
${kperthread},
${m1n1_thcluster_m1xs},
${m1n1_thcluster_n1xs},
${ABT_thread_slice_lengths_K0_M0_M1_K1},
${ABT_thread_cluster_lengths_K0_M0_M1_K1},
${ABT_thread_cluster_arrange_order},
${ABT_src_access_order},
${ABT_src_vec_tensor_lengths_K0_M0_M1_K1},
${ABT_src_vec_tensor_cont_dim_order},
${ABT_dst_vec_tensor_lengths_K0_M0_M1_K1},
${BBT_thread_slice_lengths_K0_N0_N1_K1},
${BBT_thread_cluster_lengths_K0_N0_N1_K1},
${BBT_thread_cluster_arrange_order},
${BBT_src_access_order},
${BBT_src_vec_tensor_lengths_K0_N0_N1_K1},
${BBT_src_vec_tensor_cont_dim_order},
${BBT_dst_vec_tensor_lengths_K0_N0_N1_K1},
${CTT_src_dst_access_order},
${CTT_src_dst_vec_dim},
${CTT_dst_scalar_per_vector}>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<${type_a}> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ${layout_a}{}));
Tensor<${type_b}> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ${layout_b}{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<${type_a}>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<${type_b}>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<${type_a}>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<${type_b}>{-1.f, 1.f}(b_k_n);
}
Tensor<${type_c}> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<${type_c}> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(${type_a}) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(${type_b}) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(${type_c}) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ${elementwise_op_a}{};
auto b_element_op = ${elementwise_op_b}{};
auto c_element_op = ${elementwise_op_c}{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<${type_a}*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<${type_b}*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<${type_c}*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(${type_a}) * M * K + sizeof(${type_b}) * K * N + sizeof(${type_c}) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool run_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
"""
    def emit(self):
        values = {
            'type_a' : 'ck::half_t',
            'type_b' : 'ck::half_t',
            'type_c' : 'ck::half_t',
            'type_acc' : 'float',
            'layout_a' : 'ck::tensor_layout::gemm::ColumnMajor',
            'layout_b' : 'ck::tensor_layout::gemm::RowMajor',
            'layout_c' : 'ck::tensor_layout::gemm::RowMajor',
            'elementwise_op_a' : 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_b' : 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_c' : 'ck::tensor_operation::element_wise::PassThrough',
            'Gemm_spec' : 'ck::tensor_operation::device::GemmSpecialization::Default',
            'block_size' : '256',
            'mperblock' : '128',
            'nperblock' : '128',
            'k0perblock' : '16',
            'k1' : '2',
            'm1perthread' : '4',
            'n1perthread' : '4',
            'kperthread' : '1',
            'm1n1_thcluster_m1xs' : 'S<8, 2>',
            'm1n1_thcluster_n1xs' : 'S<8, 2>',
            'ABT_thread_slice_lengths_K0_M0_M1_K1' : 'S<2, 1, 4, 2>',
            'ABT_thread_cluster_lengths_K0_M0_M1_K1' : 'S<8, 1, 32, 1>',
            'ABT_thread_cluster_arrange_order' : 'S<0, 3, 1, 2>',
            'ABT_src_access_order' : 'S<0, 3, 1, 2>',
            'ABT_src_vec_tensor_lengths_K0_M0_M1_K1' : 'S<1, 1, 4, 1>',
            'ABT_src_vec_tensor_cont_dim_order' : 'S<0, 3, 1, 2>',
            'ABT_dst_vec_tensor_lengths_K0_M0_M1_K1' : 'S<1, 1, 4, 2>',
            'BBT_thread_slice_lengths_K0_N0_N1_K1' : 'S<2, 1, 4, 2>',
            'BBT_thread_cluster_lengths_K0_N0_N1_K1' : 'S<8, 1, 32, 1>',
            'BBT_thread_cluster_arrange_order' : 'S<0, 3, 1, 2>',
            'BBT_src_access_order' : 'S<0, 3, 1, 2>',
            'BBT_src_vec_tensor_lengths_K0_N0_N1_K1' : 'S<1, 1, 4, 1>',
            'BBT_src_vec_tensor_cont_dim_order' : 'S<0, 3, 1, 2>',
            'BBT_dst_vec_tensor_lengths_K0_N0_N1_K1' : 'S<1, 1, 4, 2>',
            'CTT_src_dst_access_order' : 'S<0, 1, 2, 3, 4, 5>',
            'CTT_src_dst_vec_dim' : '5',
            'CTT_dst_scalar_per_vector' : '4'
        }
        # render the GEMM instance source and write it to ex.cpp
        template = self.gemm_devop_template
        instance_src = SubstituteTemplate(template, values)
        print(instance_src)
        with open("ex.cpp", 'w') as cf:
            cf.write(instance_src)
        # render the Makefile for the instance
        m_template = self.make_template
        makefile_src = SubstituteTemplate(m_template, values)
        print(makefile_src)
        with open("Makefile", 'w') as cf:
            cf.write(makefile_src)
        # build the generated instance
        proc = subprocess.Popen(
            ["make"],
            shell=True,
            env=os.environ.copy(),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        out, err = proc.communicate()


a = EmitGemmInstance()
a.emit()
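# Possible extension toward the shared-library flow (sketch only; `configs` and the
# instance_<i>.cpp names are hypothetical, not part of this commit): emit one source
# file per tuning configuration and list the resulting objects in {{obj_files}} of the
# templated Makefile so they link into one shared library.
#
#   gen = EmitGemmInstance()
#   for i, cfg in enumerate(configs):  # configs: list of `values`-style dicts
#       with open("instance_%d.cpp" % i, "w") as f:
#           f.write(SubstituteTemplate(gen.gemm_devop_template, cfg))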
import enum
import os.path
import shutil
import functools
import operator
import collections
import re
def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text
class EmitGemmInstance:
    def __init__(self):
        self.gemm_devop_template = """
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
256,
128,
128,
16,
2,
4,
4,
1,
S<8, 2>,
S<8, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool run_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
"""
    def emit(self):
        values = {
            'type_a' : 'ck::half_t',
        }
        # the template above is fully specialized, so substitution is a no-op here
        template = self.gemm_devop_template
        instance_src = SubstituteTemplate(template, values)
        print(instance_src)
        with open("xx.cpp", 'w') as cf:
            cf.write(instance_src)


a = EmitGemmInstance()
a.emit()
CC = /opt/rocm/bin/hipcc
CK_PATH=/dockerx/composable_kernel/
CFLAGS = -O3 -std=c++17 -DCK_AMD_GPU_GFX90A --offload-arch=gfx90a -I"${CK_PATH}/include" -I"${CK_PATH}/library/include" -I"${CK_PATH}/profiler/include"
OBJS = ex.o host_tensor.o device_memory.o
all: $(OBJS)
$(CC) $(CFLAGS) $(OBJS) -o ex
device_memory.o: ../../library/src/utility/device_memory.cpp
$(CC) $(CFLAGS) -c ../../library/src/utility/device_memory.cpp
host_tensor.o: ../../library/src/utility/host_tensor.cpp
$(CC) $(CFLAGS) -c ../../library/src/utility/host_tensor.cpp
import enum
import os.path
import shutil
import functools
import operator
import collections
import re
def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text
class EmitGemmInstance:
    def __init__(self):
        self.gemm_devop_template = """
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
256,
128,
128,
16,
2,
4,
4,
1,
S<8, 2>,
S<8, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool run_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
"""
    def emit(self):
        values = {
            'type_a' : 'ck::half_t',
        }
        # the template above is fully specialized, so substitution is a no-op here
        template = self.gemm_devop_template
        instance_src = SubstituteTemplate(template, values)
        print(instance_src)
        with open("xx.cpp", 'w') as cf:
            cf.write(instance_src)


a = EmitGemmInstance()
a.emit()

# TODO: take GEMM parameters as user input and feed them into the example template
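# One possible shape for that TODO (sketch only; the argparse flags and names are
# hypothetical, not part of this commit):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--type_a", default="ck::half_t")
#   args = parser.parse_args()
#   user_values = {'type_a': args.type_a}
#   # pass user_values to SubstituteTemplate in emit() instead of the hard-coded dict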
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmDl : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
DefaultBlock2CTileMap block_2_ctile_map_;
// TODO: unused, but may be useful in future.
index_t M01_;
index_t N01_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmDl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
}
else
{
return false;
}
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
typename M11N11ThreadClusterM110Xs,
typename M11N11ThreadClusterN110Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
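// Worked example for the two predicates above (K0PerBlock = 16, as in the generated
// instances): for K0 = 64, (64 + 16) / (2 * 16) = 2 > 1, so the main double-buffered
// loop executes, and (64 / 16) % 2 == 0, so the tail handles two K0 blocks; for
// K0 = 48 the tail would instead handle a single K0 block.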
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_k0_m0_m1_k1 =
transform_tensor_descriptor(a_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_grid_desc_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_k0_n0_n1_k1 =
transform_tensor_descriptor(b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_grid_desc_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
N1PerThreadN111>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap& block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_k0_m0_m1_k1,
make_multi_index(0, im0, 0, 0),
a_block_desc_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_block_desc_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(b_grid_desc_k0_n0_n1_k1,
make_multi_index(0, in0, 0, 0),
b_block_desc_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
                // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
};
} // namespace ck
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Backend-agnostic functions for elementwise codegen.
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import jinja2
from aitemplate.backend.backend_spec import BackendSpec
from ...compiler.base import IntImm, IntVar, Operator, Tensor
from ...compiler.tensor_accessor import TensorAccessor
from ...utils import shape_utils
from . import tensor_accessor_codegen
CONSTANT_TEMPLATE = jinja2.Template(
"""
#define FUSED_ELE_THREAD_SIZE 256
const int N_ELEMENTS_PER_THREAD = sizeof({{read_t}}) / sizeof({{data_t}});
const int N_ELEMENTS_PER_READ = sizeof({{read_t}}) / sizeof({{data_t}});
const int N_OPS_PER_THREAD = sizeof({{read_t}}) / sizeof({{op_t}});
"""
)
KERNEL_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const {{read_t}}* input{{idx}}")
KERNEL_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("{{read_t}}* output{{idx}}")
KERNEL_TMP_INPUT_TEMPLATE = jinja2.Template("p_tmp_i{{idx}}[i]")
KERNEL_TMP_OUTPUT_TEMPLATE = jinja2.Template("p_tmp_o{{idx}}[i]")
GET_STRIDED_ADDRESS_TEMPLATE = jinja2.Template(
"""
{% if tensor_accessor.is_contiguous %}
{{data_ptr}} = get_strided_address</*data_t*/ {{data_t}},
/*read_t*/ {{read_t}},
/*is_contiguous*/ true>(
{{data_ptr}}, {{data_idx}}, {{tensor_accessor.offset}}, 0, 0);
{% else %}
{{data_ptr}} = get_strided_address</*data_t*/ {{data_t}},
/*read_t*/ {{read_t}},
/*is_contiguous*/ false>(
{{data_ptr}}, {{data_idx}},
{{tensor_accessor.offset}},
{{tensor_accessor.original_total_elements_from_stride_dim}},
{{tensor_accessor.actual_total_elements_from_stride_dim}});
{% endif %}
"""
)
KERNEL_READ_INPUT_TEMPLATE = jinja2.Template(
"""
{{read_t}} *{{input_name}} = const_cast<{{read_t}}*>(input{{input_idx}});
{{get_strided_address}}
{{read_t}} tmp_i{{input_idx}} = *{{input_name}};
const {{op_t}}* p_tmp_i{{input_idx}} = reinterpret_cast<const {{op_t}}*>(&tmp_i{{input_idx}});
"""
)
KERNEL_DEFINE_OUTPUTS_TEMPLATE = jinja2.Template(
"""
{% for idx in indexes %}
{{read_t}} tmp_o{{idx}};
{{op_t}}* p_tmp_o{{idx}} = reinterpret_cast<{{op_t}}*>(&tmp_o{{idx}});
{% endfor %}
"""
)
KERNEL_WRITE_OUTPUT_TEMPLATE = jinja2.Template(
"""
{{get_strided_address}}
*{{output_name}} = tmp_o{{output_idx}};
"""
)
KERNEL_TEMPLATE = jinja2.Template(
"""
__global__ void
{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{index_type}} n_elements) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const {{index_type}} idx = bid * FUSED_ELE_THREAD_SIZE + tid;
const {{index_type}} idx_elem = idx * N_ELEMENTS_PER_THREAD;
if (idx_elem >= n_elements) {
return;
}
{{read_inputs}}
{{define_outputs}}
#pragma unroll
for (int i = 0; i < N_OPS_PER_THREAD; ++i) {
{{fused_funcs}}
}
{{write_outputs}}
}
"""
)
FUNC_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const void* input{{idx}}")
FUNC_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("void* output{{idx}}")
KERNEL_CALL_INPUT_PARAM_TEMPLATE = jinja2.Template(
"reinterpret_cast<const {{read_t}}*>(input{{idx}})"
)
KERNEL_CALL_OUTPUT_PARAM_TEMPLATE = jinja2.Template(
"reinterpret_cast<{{read_t}}*>(output{{idx}})"
)
FUNC_TEMPLATE = jinja2.Template(
"""
{{head}}
namespace {
{{constant}}
{{custom_libs}}
{{tensor_accessor_lib}}
{{kernel_function}}
} // namespace
void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims_decl}} {{index_type}} n_elements, {{prefix}}Stream_t stream) {
if (n_elements == 0) {
return;
}
int block_size = static_cast<int>(std::ceil(static_cast<double>(n_elements) / N_ELEMENTS_PER_THREAD / FUSED_ELE_THREAD_SIZE));
{{func_name}}<<<block_size, FUSED_ELE_THREAD_SIZE, 0, stream>>>(
{{kernel_call_output_params}},
{{kernel_call_input_params}},
{{dynamic_dims_call}}
n_elements
);
}
"""
)
FUNC_DECL_TEMPLATE = jinja2.Template(
"""
void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{index_type}} n_elements, {{prefix}}Stream_t stream);
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{
{{indent}}{{index_type}} {{func_name}}_n_elements = {{calculate_n}};
{{indent}}invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{func_name}}_n_elements, {{stream}});
{{indent}}}
"""
)
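# Illustrative sketch (not used by the codegen path): how the launch arithmetic
# in CONSTANT_TEMPLATE and FUNC_TEMPLATE above works out. The byte sizes below
# are assumptions for a packed 16-byte read type over 2-byte half data; the
# real values come from the backend spec.
def _example_block_count(n_elements: int,
                         read_t_bytes: int = 16,
                         data_t_bytes: int = 2,
                         thread_size: int = 256) -> int:
    """Return the grid size the generated invoke_* wrapper would launch."""
    import math

    # N_ELEMENTS_PER_THREAD = sizeof(read_t) / sizeof(data_t)
    n_elements_per_thread = read_t_bytes // data_t_bytes
    return math.ceil(n_elements / n_elements_per_thread / thread_size)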
@dataclass
class ElementwiseMetaData:
func_name: str
op_t: str
args: List[Tensor]
outputs: List[Tensor]
@dataclass
class FusedElementwiseMetaData:
# Input / output Tensors and TensorAccessors.
inputs: List[Tensor]
outputs: List[Tensor]
input_accessors: List[TensorAccessor]
output_accessors: List[TensorAccessor]
# Original input / output Tensors before graph transformation.
# Kept here for elementwise -> fused elementwise Tensor mapping.
original_inputs: List[Tensor]
original_outputs: List[Tensor]
read_t: str
op_t: str
data_t: str
input_broadcast_sizes: List[List[IntVar]]
dynamic_dims: List[IntVar]
sub_funcs: List[ElementwiseMetaData]
def gen_function_single_thread(
fused_func_metadata,
input_names,
output_names,
type_converter,
) -> str:
"""Per thread elementwise function codegen."""
tensor_to_expr: Dict[Tensor, str] = {}
body = ""
for tensor, name in zip(fused_func_metadata.original_inputs, input_names):
tensor_to_expr[tensor] = name
tmp_output_idx: int = 0
for func_metadata in fused_func_metadata.sub_funcs:
params: List[str] = []
func_op_t = func_metadata.op_t
input_converter = None
output_converter = None
if func_op_t != fused_func_metadata.op_t:
input_converter = type_converter.get(fused_func_metadata.op_t).get(
func_op_t
)
output_converter = type_converter.get(func_op_t).get(
fused_func_metadata.op_t
)
assert (
input_converter is not None
), "Unsupported convertion from {} to {}".format(
fused_func_metadata.op_t, func_op_t
)
assert (
output_converter is not None
), "Unsupported convertion from {} to {}".format(
func_op_t, fused_func_metadata.op_t
)
for arg in func_metadata.args:
if arg in tensor_to_expr:
param = tensor_to_expr[arg]
params.append(
"{}({})".format(input_converter, param)
if input_converter is not None
else param
)
elif arg.is_a_const_num():
if func_op_t[-1] == "2":
params.append(
"{}({},{})".format(
func_op_t,
str(arg._attrs["value"]),
str(arg._attrs["value"]),
)
)
else:
params.append("{}({})".format(func_op_t, str(arg._attrs["value"])))
else:
raise RuntimeError(
"Cannot generate expression for node {}, ops: {}".format(
arg, func_metadata
)
)
assert (
len(func_metadata.outputs) == 1
), "Operator has more than 1 output! Operator: {}".format(func_metadata)
output = func_metadata.outputs[0]
func_def = "{}({})".format(func_metadata.func_name, ",".join(params))
func_def = (
"{}({})".format(output_converter, func_def)
if output_converter is not None
else func_def
)
if len(output._attrs["dst_ops"]) > 1:
name = "tmp_" + (str)(tmp_output_idx)
tmp_output_idx += 1
body += "{} {} = {};\n".format(fused_func_metadata.op_t, name, func_def)
tensor_to_expr[output] = name
else:
tensor_to_expr[output] = func_def
for tensor, name in zip(fused_func_metadata.original_outputs, output_names):
if tensor not in tensor_to_expr:
raise RuntimeError(
"Cannot generate expression for node {}, outputs: {}".format(
tensor, fused_func_metadata.original_outputs
)
)
expr = tensor_to_expr[tensor]
body += "{} = {};\n".format(name, expr)
return body
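# Example of the generated per-thread body (hypothetical backend function name;
# the real names come from backend_spec.func_enum_to_func_name). For a single
# fused ADD over a half2 op_t, the body would look roughly like:
#   p_tmp_o0[i] = __hadd2(p_tmp_i0[i],p_tmp_i1[i]);
# Intermediate results feeding more than one downstream op get a named
# temporary ("tmp_0", "tmp_1", ...) instead of being inlined.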
def _get_sub_func_metadata(
ops: List[Operator], data_t: str, op_t: str, backend_spec: BackendSpec
) -> Tuple[List[ElementwiseMetaData], str]:
candidate_op_types = backend_spec.get_candidate_op_types(op_t)
func_enums = []
for op in ops:
func_enum = op._attrs["func"]
func_enums.append(func_enum)
funcs = backend_spec.func_enum_to_func_name.get(func_enum)
if funcs is None:
raise NotImplementedError("Func {} is not supported!".format(func_enum))
for candidate_op_t in candidate_op_types:
func_name = funcs.get(candidate_op_t)
if func_name is not None:
candidate_op_types = backend_spec.get_candidate_op_types(candidate_op_t)
break
if len(candidate_op_types) == 0:
raise RuntimeError(
"Cannot find a common rocm data type! candidate_op_types: {}, op_t: {}.".format(
candidate_op_types, op_t
)
)
if op_t in set(candidate_op_types):
op_t = candidate_op_types[0]
else:
op_t = data_t
candidate_op_types = backend_spec.get_candidate_op_types(op_t)
sub_func_metadata = []
for op in ops:
func_enum = op._attrs["func"]
funcs = backend_spec.func_enum_to_func_name.get(func_enum)
func_name = None
func_op_t = None
for candidate_op_t in candidate_op_types:
func_name = funcs.get(candidate_op_t)
if func_name is not None:
func_op_t = candidate_op_t
break
if func_name is None:
raise NotImplementedError(
"Unsupported func {} and op type {}!".format(func_enum, op_t)
)
sub_func_metadata.append(
ElementwiseMetaData(
func_name, func_op_t, op._attrs["args"], op._attrs["outputs"]
)
)
return (sub_func_metadata, op_t)
def _get_types_and_sizes(
inputs: List[Tensor],
input_accessors: List[TensorAccessor],
output_accessors: List[TensorAccessor],
backend_spec: BackendSpec,
) -> Tuple[int, List[List[IntVar]], str]:
"""
Returns Tuple(alignment, input_broadcast_sizes, dtype)
"""
# Handle input broadcast.
output_shape = output_accessors[0].original_shapes
dtype = inputs[0]._attrs["dtype"]
input_broadcast_sizes = []
min_num_elements = None
for input_accessor in input_accessors:
input_shape = input_accessor.original_shapes
broadcastable, _ = shape_utils.get_broadcast_max_shape(
output_shape, input_shape
)
if not broadcastable:
raise RuntimeError(
"Input shape {} is not compatible with output shape {}!".format(
input_shape, output_shape
)
)
num_rightmost_non_broadcast_elements = len(input_shape)
extended_input_shape = list(input_shape)
if input_shape == output_shape:
input_broadcast_sizes.append(None)
else:
extended_input_shape = [IntImm(1)] * len(output_shape)
extended_input_shape[len(output_shape) - len(input_shape) :] = input_shape
input_broadcast_sizes.append(extended_input_shape)
for i in reversed(range(len(extended_input_shape))):
if extended_input_shape[i] != output_shape[i]:
num_rightmost_non_broadcast_elements -= i + 1
break
num_elements_for_alignments = shape_utils.get_num_rightmost_static_elements(
extended_input_shape, num_rightmost_non_broadcast_elements
)
if not min_num_elements:
min_num_elements = num_elements_for_alignments
else:
min_num_elements = min(min_num_elements, num_elements_for_alignments)
alignment = tensor_accessor_codegen.find_max_alignment(
min_num_elements, output_accessors
)
# Note that we use the same alignment for accessing inputs and outputs, although
# they may have different alignment requirements. We may lose perf a little bit,
# but reduce the complexity of our jinja template. We can do some perf
# experiments later to determine if we want to chase more perf gains.
alignment = tensor_accessor_codegen.find_max_alignment(alignment, input_accessors)
return alignment, input_broadcast_sizes, dtype
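# Worked example (a sketch, not executed here): for an output of shape
# [B, M, N] and an input of shape [N], the input is extended to [1, 1, N] and
# recorded in input_broadcast_sizes; an input whose shape already equals the
# output shape is recorded as None. The returned alignment is the largest
# value accepted by both the output and input accessors for the rightmost
# static, non-broadcast elements.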
def _get_dynamic_dims(output_accessors: List[TensorAccessor]) -> List[IntVar]:
res = {}
for output_accessor in output_accessors:
for dim in output_accessor.original_shapes:
if not isinstance(dim, IntImm):
res[dim._attrs["name"]] = dim
return res.values()
def _parse_func_metadata(
ops: List[Operator],
inputs: List[Tensor],
outputs: List[Tensor],
input_accessors: List[TensorAccessor],
output_accessors: List[TensorAccessor],
original_inputs: List[Tensor],
original_outputs: List[Tensor],
backend_spec: BackendSpec,
) -> FusedElementwiseMetaData:
alignment, input_broadcast_sizes, dtype = _get_types_and_sizes(
inputs, input_accessors, output_accessors, backend_spec
)
read_type = backend_spec.get_backend_type(
alignment, dtype, backend_spec.read_num_elements_to_backend_type
)
op_type = backend_spec.get_backend_type(
alignment, dtype, backend_spec.op_num_elements_to_backend_type
)
data_type = backend_spec.dtype_to_backend_type(dtype)
sub_func_metadata, op_type = _get_sub_func_metadata(
ops, data_type, op_type, backend_spec
)
dynamic_dims = _get_dynamic_dims(output_accessors)
return FusedElementwiseMetaData(
inputs,
outputs,
input_accessors,
output_accessors,
original_inputs,
original_outputs,
read_type,
op_type,
data_type,
input_broadcast_sizes,
dynamic_dims,
sub_func_metadata,
)
def _gen_int_var_product_str(
int_vars: List[IntVar],
) -> str:
res = []
for int_var in int_vars:
if isinstance(int_var, IntImm):
res.append(str(int_var._attrs["values"][0]))
elif isinstance(int_var, IntVar):
res.append(int_var._attrs["name"])
else:
raise RuntimeError(
"A dim must be an IntVar! Current type: {}".format(type(int_var))
)
return " * ".join(res)
def _gen_input_broadcast_calculator_str(
input_shape: List[IntVar],
output_shape: List[IntVar],
) -> str:
output_num_elements = []
output_strides = []
input_strides = []
start_idx = 0
for i, (input_dim, output_dim) in enumerate(zip(input_shape, output_shape)):
if input_dim != output_dim:
assert input_dim == IntImm(
1
), "Unexpected shapes! Input: {}, output: {}".format(
input_shape, output_shape
)
input_strides.append(input_shape[i:])
output_strides.append(output_shape[i:])
output_num_elements.append(output_shape[start_idx:])
start_idx = i + 1
if start_idx < len(output_shape):
input_strides.append([IntImm(1)])
output_strides.append([IntImm(1)])
output_num_elements.append(output_shape[start_idx:])
res = []
for (output_num_element, output_stride, input_stride) in zip(
output_num_elements, output_strides, input_strides
):
res.append(
"{} % ({}) / ({}) * ({})".format(
"idx * N_ELEMENTS_PER_THREAD",
_gen_int_var_product_str(output_num_element),
_gen_int_var_product_str(output_stride),
_gen_int_var_product_str(input_stride),
)
)
return " + ".join(res)
def _gen_input_broadcast_size_str(
input_broadcast_sizes: List[List[IntVar]],
output_shape: List[IntVar],
) -> List[str]:
res = []
for input_broadcast_size in input_broadcast_sizes:
if input_broadcast_size is None:
res.append("")
else:
res.append(
_gen_input_broadcast_calculator_str(input_broadcast_size, output_shape)
)
return res
def _gen_dynamic_dim_str(
index_type: str, dynamic_dims: List[IntVar], has_type: bool
) -> str:
type_str = index_type + " " if has_type else ""
res = ", ".join([type_str + dim._attrs["name"] for dim in dynamic_dims])
if res:
res += ", "
return res
def _gen_read_inputs_str(
fused_elementwise_metadata: FusedElementwiseMetaData, broadcast_sizes: List[str]
):
read_inputs = []
for input_idx, (input_accessor, broadcast_size) in enumerate(
zip(fused_elementwise_metadata.input_accessors, broadcast_sizes)
):
input_name = f"input_tmp{input_idx}"
data_idx = (
"idx"
if not broadcast_size
else f"({broadcast_size}) / N_ELEMENTS_PER_THREAD"
)
get_strided_addr_str = GET_STRIDED_ADDRESS_TEMPLATE.render(
tensor_accessor=input_accessor,
data_ptr=input_name,
data_t=fused_elementwise_metadata.data_t,
read_t=fused_elementwise_metadata.read_t,
data_idx=data_idx,
)
read_input = KERNEL_READ_INPUT_TEMPLATE.render(
get_strided_address=get_strided_addr_str,
input_name=input_name,
input_idx=input_idx,
read_t=fused_elementwise_metadata.read_t,
op_t=fused_elementwise_metadata.op_t,
)
read_inputs.append(read_input)
read_inputs_str = "\n".join(read_inputs)
return read_inputs_str
def _gen_write_outputs_str(fused_elementwise_metadata: FusedElementwiseMetaData):
write_outputs = []
for output_idx, output_accessor in enumerate(
fused_elementwise_metadata.output_accessors
):
output_name = f"output{output_idx}"
get_strided_addr_str = GET_STRIDED_ADDRESS_TEMPLATE.render(
tensor_accessor=output_accessor,
data_ptr=output_name,
data_t=fused_elementwise_metadata.data_t,
read_t=fused_elementwise_metadata.read_t,
data_idx="idx",
)
write_out = KERNEL_WRITE_OUTPUT_TEMPLATE.render(
get_strided_address=get_strided_addr_str,
output_name=output_name,
output_idx=output_idx,
)
write_outputs.append(write_out)
write_outputs_str = "\n".join(write_outputs)
return write_outputs_str
def _gen_kernel_function(
func_attrs: Dict[str, Any],
index_type: str,
fused_elementwise_metadata: FusedElementwiseMetaData,
backend_datatype_convertors: Dict[str, Dict[str, str]],
) -> str:
output_params_decl = ",".join(
[
KERNEL_DECL_OUTPUT_PARAM_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t, idx=i
)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
)
input_params_decl = ",".join(
[
KERNEL_DECL_INPUT_PARAM_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t, idx=i
)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
)
broadcast_sizes = _gen_input_broadcast_size_str(
fused_elementwise_metadata.input_broadcast_sizes,
fused_elementwise_metadata.output_accessors[0].original_shapes,
)
read_inputs_str = _gen_read_inputs_str(fused_elementwise_metadata, broadcast_sizes)
define_outputs = KERNEL_DEFINE_OUTPUTS_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t,
op_t=fused_elementwise_metadata.op_t,
indexes=list(range(len(fused_elementwise_metadata.outputs))),
)
write_outputs_str = _gen_write_outputs_str(fused_elementwise_metadata)
input_names = [
KERNEL_TMP_INPUT_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
output_names = [
KERNEL_TMP_OUTPUT_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
fused_funcs = gen_function_single_thread(
fused_elementwise_metadata,
input_names,
output_names,
backend_datatype_convertors,
)
kernel_func = KERNEL_TEMPLATE.render(
func_name=func_attrs["name"],
index_type=index_type,
output_params=output_params_decl,
input_params=input_params_decl,
dynamic_dims=_gen_dynamic_dim_str(
index_type, fused_elementwise_metadata.dynamic_dims, has_type=True
),
read_inputs=read_inputs_str,
define_outputs=define_outputs,
write_outputs=write_outputs_str,
fused_funcs=fused_funcs,
)
return kernel_func
def fused_elementwise_gen_function(
func_attrs: Dict[str, Any],
custom_libs: str,
head_template: str,
backend_spec: BackendSpec,
) -> str:
"""Generates fused_elementwise function definition."""
ops = func_attrs["elementwise_ops"]
inputs = func_attrs["inputs"]
outputs = func_attrs["outputs"]
input_accessors = func_attrs["input_accessors"]
output_accessors = func_attrs["output_accessors"]
original_inputs = func_attrs["original_inputs"]
original_outputs = func_attrs["original_outputs"]
fused_elementwise_metadata = _parse_func_metadata(
ops,
inputs,
outputs,
input_accessors,
output_accessors,
original_inputs,
original_outputs,
backend_spec,
)
    # Dump data types into func_attrs for testing purposes.
func_attrs["read_t"] = fused_elementwise_metadata.read_t
func_attrs["op_t"] = fused_elementwise_metadata.op_t
func_attrs["data_t"] = fused_elementwise_metadata.data_t
tensor_accessor_lib = tensor_accessor_codegen.get_libs()
tensor_accessor_lib_str = "\n\n" + tensor_accessor_lib + "\n\n"
kernel_function = _gen_kernel_function(
func_attrs,
backend_spec.index_type,
fused_elementwise_metadata,
backend_spec.backend_datatype_convertors,
)
output_params_decl = ",".join(
[
FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
)
input_params_decl = ",".join(
[
FUNC_DECL_INPUT_PARAM_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
)
kernel_call_output_params = ",".join(
[
KERNEL_CALL_OUTPUT_PARAM_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t, idx=i
)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
)
kernel_call_input_params = ",".join(
[
KERNEL_CALL_INPUT_PARAM_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t, idx=i
)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
)
constant = CONSTANT_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t,
op_t=fused_elementwise_metadata.op_t,
data_t=fused_elementwise_metadata.data_t,
)
function = FUNC_TEMPLATE.render(
prefix=backend_spec.prefix,
index_type=backend_spec.index_type,
head=backend_spec.header_src_template.render(extra_header=head_template),
constant=constant,
custom_libs=custom_libs,
tensor_accessor_lib=tensor_accessor_lib_str,
kernel_function=kernel_function,
func_name=func_attrs["name"],
output_params=output_params_decl,
input_params=input_params_decl,
dynamic_dims_decl=_gen_dynamic_dim_str(
backend_spec.index_type,
fused_elementwise_metadata.dynamic_dims,
has_type=True,
),
dynamic_dims_call=_gen_dynamic_dim_str(
backend_spec.index_type,
fused_elementwise_metadata.dynamic_dims,
has_type=False,
),
kernel_call_output_params=kernel_call_output_params,
kernel_call_input_params=kernel_call_input_params,
)
return function
def fused_elementwise_gen_function_decl(
func_attrs,
backend_spec: BackendSpec,
):
"""Generates fused_elementwise function declaration."""
func_name = func_attrs["name"]
ops = func_attrs["elementwise_ops"]
inputs = func_attrs["inputs"]
outputs = func_attrs["outputs"]
input_accessors = func_attrs["input_accessors"]
output_accessors = func_attrs["output_accessors"]
original_inputs = func_attrs["original_inputs"]
original_outputs = func_attrs["original_outputs"]
fused_elementwise_metadata = _parse_func_metadata(
ops,
inputs,
outputs,
input_accessors,
output_accessors,
original_inputs,
original_outputs,
backend_spec,
)
output_params_decl = ",".join(
[
FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
)
input_params_decl = ",".join(
[
FUNC_DECL_INPUT_PARAM_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
)
function_decl = FUNC_DECL_TEMPLATE.render(
prefix=backend_spec.prefix,
index_type=backend_spec.index_type,
func_name=func_name,
output_params=output_params_decl,
input_params=input_params_decl,
dynamic_dims=_gen_dynamic_dim_str(
backend_spec.index_type,
fused_elementwise_metadata.dynamic_dims,
has_type=True,
),
)
return function_decl
def fused_elementwise_gen_function_call(
func_attrs,
indent: str,
backend_spec: BackendSpec,
):
"""Generates fused_elementwise function call."""
ops = func_attrs["elementwise_ops"]
inputs = func_attrs["inputs"]
outputs = func_attrs["outputs"]
input_accessors = func_attrs["input_accessors"]
output_accessors = func_attrs["output_accessors"]
original_inputs = func_attrs["original_inputs"]
original_outputs = func_attrs["original_outputs"]
fused_elementwise_metadata = _parse_func_metadata(
ops,
inputs,
outputs,
input_accessors,
output_accessors,
original_inputs,
original_outputs,
backend_spec,
)
output_params = ",".join([output._attrs["name"] for output in outputs])
input_params = ",".join([input._attrs["name"] for input in inputs])
num_elements_calculator = _gen_int_var_product_str(
output_accessors[0].original_shapes
)
return FUNC_CALL_TEMPLATE.render(
stream=backend_spec.stream,
func_name=func_attrs["name"],
index_type=backend_spec.index_type,
calculate_n=num_elements_calculator,
output_params=output_params,
input_params=input_params,
dynamic_dims=_gen_dynamic_dim_str(
backend_spec.index_type,
fused_elementwise_metadata.dynamic_dims,
has_type=False,
),
indent=indent,
)
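# Note on how the three generators above fit together (a sketch; the exact
# caller lives in the fused_elementwise op lowering):
# fused_elementwise_gen_function() emits the kernel plus its invoke_<name>
# wrapper, fused_elementwise_gen_function_decl() emits the matching prototype,
# and fused_elementwise_gen_function_call() emits the call-site block that
# computes n_elements and invokes the wrapper on the given stream.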
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Common template for gemm
"""
import os
import re
from collections import OrderedDict
from hashlib import sha1
import jinja2
from ...common import gemm_common
from ...target import Target
# pylint: disable=C0103,C0415,W0611,C0301
EXTRA_SHAPE_TEMPLATE = jinja2.Template(
"""
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
"""
)
INSTANCE_TEMPLATE = jinja2.Template(
"""
{{config}}
using {{name}} = {{ config_name }};
"""
)
EXEC_TEMPLATE = jinja2.Template(
"""
{{indent}}auto op = {{instance}}{};
{{indent}}auto invoker = op.MakeInvoker();
{{indent}}auto argument = op.MakeArgument(
{{problem_args}}
{{indent}});
{{indent}}if(!op.IsSupportedArgument(argument)) {
{{indent}} throw std::runtime_error(
{{indent}} "wrong! device_gemm with the specified compilation parameters does "
{{indent}} "not support this Gemm problem");
{{indent}}}
{% if is_profiler %}
{{indent}}auto workspace_size = op.GetWorkSpaceSize(&argument);
{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size;
{% endif %}
{{indent}}invoker.Run(argument, StreamConfig{stream, false});
{{indent}}return;
"""
)
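# EXEC_TEMPLATE mirrors the usual CK device-op calling sequence
# (MakeInvoker -> MakeArgument -> IsSupportedArgument -> invoker.Run); the
# generated gemm_rrr_* sources later in this commit follow the same sequence.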
EXTRA_HEADER_TEMPLATE = jinja2.Template(
"""
{% if gemm_flag == "" %}
#include "include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
{% elif gemm_flag == "permute_m2n3" %}
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% elif "bias" in gemm_flag or has_d0 %}
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
{% if gemm_flag == "bias_permute" %}
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
{% elif gemm_flag in ["bias_permute_m2n3", "bias_permute_m3n2"] %}
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% endif %}
{% endif %}
"""
)
SRC_TEMPLATE = jinja2.Template(
"""
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
{{extra_header}}
{{extra_code}}
{{instances}}
void {{function_name}}(
void * in_ptr,
void * weight_ptr,
void * out_ptr,
{% if "bias" in gemm_flag %}
void * bias_ptr,
{% endif %}
{% if has_d0 %}
void * d0_ptr,
{% endif %}
{% if has_d1 %}
void * d1_ptr,
{% endif %}
{% for idx in range(ndims) %}
int64_t* a_dim{{idx}},
{% endfor %}
{% for idx in range(ndims) %}
int64_t* b_dim{{idx}},
{% endfor %}
{% for idx in range(ndims) %}
int64_t* c_dim{{idx}},
{% endfor %}
{% for idx in range(pdims) %}
const int p_dim{{idx}},
{% endfor %}
hipStream_t stream
) {
{{shape_func}}
{{extra_shape}}
{{input_addr_calculator}}
{{output_addr_calculator}}
{{exec_paths}}
throw std::runtime_error(
"Unsupported workload for this gemm specialization."
);
}
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{in_ptr}},
{{indent}} {{weight_ptr}},
{{indent}} {{out_ptr}},
{% if "bias" in gemm_flag %}
{{indent}} {{bias_ptr}},
{% endif %}
{% if d0_ptr != "" %}
{{indent}} {{d0_ptr}},
{% endif %}
{% if d1_ptr != "" %}
{{indent}} {{d1_ptr}},
{% endif %}
{% for dim in adims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in bdims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in cdims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in pdims %}
{{indent}} {{dim}},
{% endfor %}
{{indent}} stream
{{indent}});
"""
)
PROBLEM_ARGS_TEMPLATE = jinja2.Template(
"""
{{indent}} static_cast<ck::half_t *>(in_ptr),
{{indent}} static_cast<ck::half_t *>(weight_ptr),
{% if gemm_flag == "bias_permute" %}
{{indent}} static_cast<ck::half_t *>(bias_ptr),
{% elif gemm_flag == "bias_permute_m2n3" %}
{{indent}} std::array<const void*, 1>{static_cast<ck::half_t *>(bias_ptr)},
{% elif gemm_flag == "permute_m2n3" %}
{{indent}} {},
{% else %}
{% if "bias" in gemm_flag and not has_d0 %}
{{indent}} std::array<const void*, 1>{static_cast<ck::half_t *>(bias_ptr)},
{% elif has_d0 and not has_d1 %}
{{indent}} std::array<const void*, 2>{static_cast<ck::half_t *>(bias_ptr),
static_cast<ck::half_t *>(d0_ptr)},
{% elif has_d1 %}
{{indent}} std::array<const void*, 3>{static_cast<ck::half_t *>(bias_ptr),
static_cast<ck::half_t *>(d0_ptr),
static_cast<ck::half_t *>(d1_ptr)},
{% endif %}
{% endif %}
{{indent}} static_cast<ck::half_t *>(out_ptr),
{% if gemm_flag not in ["permute_m2n3", "bias_permute_m2n3", "bias_permute_m3n2"] %}
{{indent}} M,
{{indent}} N,
{{indent}} K,
{{indent}} stride_a,
{{indent}} stride_b,
{% endif %}
{% if gemm_flag == "bias_permute" %}
{{indent}} {M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1},
{{indent}} {M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1},
{% elif gemm_flag in ["permute_m2n3", "bias_permute_m2n3", "bias_permute_m3n2"] %}
{{indent}} a_ms_ks_lengths,
{{indent}} a_ms_ks_strides,
{{indent}} b_ns_ks_lengths,
{{indent}} b_ns_ks_strides,
{% if gemm_flag == "permute_m2n3" %}
{{indent}} {},
{{indent}} {},
{% else %}
{{indent}} std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
{{indent}} std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
{% endif %}
{{indent}} e_ms_ns_lengths,
{{indent}} e_ms_ns_strides,
{% else %}
{% if "bias" in gemm_flag and not has_d0 %}
{{indent}} std::array<ck::index_t, 1>{0},
{% elif has_d0 and not has_d1 %}
{{indent}} std::array<ck::index_t, 2>{0, static_cast<int>(stride_c)},
{% elif has_d1 %}
{{indent}} std::array<ck::index_t, 3>{0, static_cast<int>(stride_c), static_cast<int>(stride_c)},
{% endif %}
{{indent}} stride_c,
{% endif %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{},
{{indent}} ck::tensor_operation::element_wise::PassThrough{},
{% if gemm_flag == "" %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{}
{% elif gemm_flag == "permute_m2n3" %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{}
{% elif gemm_flag == "bias" or "bias_permute" in gemm_flag %}
{{indent}} ck::tensor_operation::element_wise::Add{}
{% elif gemm_flag == "bias_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddRelu{}
{% elif gemm_flag == "bias_fast_gelu" %}
{{indent}} ck::tensor_operation::element_wise::AddFastGelu{}
{% elif gemm_flag == "bias_swish" %}
{{indent}} ck::tensor_operation::element_wise::AddHardswish{}
{% elif gemm_flag == "bias_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddTanh{}
{% elif gemm_flag == "bias_sigmoid" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoid{}
{% elif gemm_flag == "bias_add" %}
{{indent}} ck::tensor_operation::element_wise::AddAdd{}
{% elif gemm_flag == "bias_mul" %}
{{indent}} ck::tensor_operation::element_wise::AddMul{}
{% elif gemm_flag == "bias_mul_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddMulTanh{}
{% elif gemm_flag == "bias_add_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddRelu{}
{% elif gemm_flag == "bias_add_fast_gelu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddFastGelu{}
{% elif gemm_flag == "bias_sigmoid_mul" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoidMul{}
{% elif gemm_flag == "bias_sigmoid_mul_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoidMulTanh{}
{% elif gemm_flag == "bias_mul_add" %}
{{indent}} ck::tensor_operation::element_wise::AddMulAdd{}
{% elif gemm_flag == "bias_add_add" %}
{{indent}} ck::tensor_operation::element_wise::AddAddAdd{}
{% elif gemm_flag == "bias_add_add_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddAddRelu{}
{% endif %}
"""
)
TENSOR_DECL_TEMPLATE = jinja2.Template(
"""
int64_t a_ptr_sz = M*K;
int64_t b_ptr_sz = N*K;
int64_t c_ptr_sz = M*N;
int64_t ptr_max_sz = std::max({a_ptr_sz, b_ptr_sz, c_ptr_sz});
// TODO: special pool size for 8M L2 cache
// need to tune it for other devices
int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 23) / ptr_max_sz)));
memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // x: index 0
memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // w: index 1
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // y: index 2
{% if "bias" in gemm_flag %}
memory_pool->AllocateHalfTensor(N, mem_pool_sz); // b: index 3
{% endif %}
{% if has_d0 %}
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d0 ptr: index 4
{% endif %}
{% if has_d1 %}
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d1 ptr: index 5
{% endif %}
"""
)
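# Worked example for the pool-size heuristic above (a sketch, assuming
# M = N = K = 1024): ptr_max_sz = 1024 * 1024 elements, so
# (1 << 23) / ptr_max_sz = 8 and mem_pool_sz = 8 rotating copies.
# RequestHalfTensorByIdx then hands each call a different copy, which (per the
# TODO above) is meant to keep repeated profiler runs from re-reading the same
# data out of an ~8 MB L2 cache.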
STRUCTS_DEF_TEMPLATE = jinja2.Template(
"""
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
  ptrs.reserve(512);
  // create the rocRAND generator used by AllocateGaussianTensor below
  rocrand_create_generator(&generator, ROCRAND_RNG_PSEUDO_DEFAULT);
}
~ProfilerMemoryPool() {
for(int i = 0; i < ptrs.size(); i++){
hipFree(ptrs[i]);
  }
  rocrand_destroy_generator(generator);
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
size_t length = size * sizeof(DType);
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
"""
)
PROFILER_TEMPLATE = jinja2.Template(
"""
size_t GLOBAL_WORKSPACE_SIZE = 0;
{{op_func}}
{{structs_def}}
int main(int argc, char** argv) {
if (argc < 4) {
throw std::runtime_error("wrong params");
}
{{args_parse}}
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
{{tensor_decl}}
// TODO: random init
// warmup
for(int i = 0; i < 3; ++i) {
{{func_call}}
}
// run
auto timer = new KernelTimerImpl();
timer->Start();
for(int i = 0; i < 5; ++i) {
{{func_call}}
}
timer->End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer->GetElapsedTime() << std::endl;
delete(timer);
}
"""
)
FUNC_DECL_TEMPLATE = jinja2.Template(
"""
void {{func_name}}(
void *,
void *,
void *,
{% if "bias" in gemm_flag %}
void *,
{% endif %}
{% if has_d0 %}
void *,
{% endif %}
{% if has_d1 %}
void *,
{% endif %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(pdims) %}
const int,
{% endfor %}
hipStream_t
);
"""
)
def has_d0(func_attrs):
return func_attrs.get("num_sources", 0) >= 1
def has_d1(func_attrs):
return func_attrs.get("num_sources", 0) >= 2
def emit_instance(op):
"""Emit instance."""
import ck_lib # noqa: F401
op_def = op.emit()
return op_def
def extract_config(op_kind, extra_kind, f_proc_op):
"""Extract (operation name, operation instance) pair
from all operation candidates.
Parameters
----------
op_kind : ck_lib.library.OperationKind
Operation kind.
extra_kind : ck_lib.library.[AnyKind]
    Used as an extra flag to distinguish kernels.
E.g. bias_add_relu vs. add_relu_bias
    f_proc_op : function
    Used to filter operations.
Returns
-------
Dict
Extracted (operation name, operation instance) pair.
"""
gemm_ops = OrderedDict()
extract_ops = list(Target.current()._operators[op_kind][extra_kind].items())
for key, value in extract_ops:
op = value[0]
ret = f_proc_op(op)
if len(ret) > 0:
for op_inst in ret:
gemm_ops[key] = op_inst
return gemm_ops
def extract_config_name(config):
"""Exract name from the statement, e.g. 'model' for 'using model = xxx'.
Parameters
----------
config : str
Configuration as a string in the format of 'using model = xxx'.
Returns
-------
str
Extracted name from the statement.
Raises
------
RuntimeError
Invalid config.
"""
pattern = re.compile(r"\s*using\s(.*?)\s=")
decl = config.split("\n")[1]
match = pattern.match(decl)
if match is None:
raise RuntimeError("Invalid config: \n" + config)
return match.groups()[0]
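# Example: emit_instance() is expected to return a config whose second line is
# the using-declaration, e.g.
#   "\nusing gemm_instance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<...>;"
# for which extract_config_name() returns "gemm_instance".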
def gen_profiler(
func_attrs,
workdir,
dim_info_dict,
args_parse,
gemm_flag,
extra_code="",
ndims=2,
extra_shape_template=EXTRA_SHAPE_TEMPLATE,
problem_args_template=PROBLEM_ARGS_TEMPLATE,
extra_header_template=EXTRA_HEADER_TEMPLATE,
tensor_decl_template=TENSOR_DECL_TEMPLATE,
):
"""Generates standalone executables for profiler.
Parameters
----------
func_attrs : Dict
Operation attributes.
workdir : str
Directory to store the generated outputs.
dim_info_dict: Dict[str, DimInfo]
Generated from gemm._extract_dims().
    Used to store the mapping from dim_names to input / output tensor dims.
args_parse: str
Profiler input argument parser.
gemm_flag : str
    Flag selecting which GEMM variant (epilogue) to generate. Options are '', 'bias', 'bias_relu', 'bias_sigmoid', 'bias_add_relu'.
extra_code : str
Extra code for self-defined operators.
ndims : int
Number of dims for each parameter, 2 for gemm, 3 for bmm
extra_shape_template: jinja2.Template
Shape evaluation template.
problem_args_template: jinja2.Template
Problem args template for profiler.
extra_header_template: jinja2.Template
Extra header template as we have different headers for gemm and bmm.
tensor_decl_template: jinja2.Template
Tensor declaration template.
"""
op_type = func_attrs["op"]
op_instance = func_attrs["op_instance"]
# shape function
op_func_shape = gemm_common.gen_shape_eval_code(
indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True
)
adims = ["&a_dim" + str(i) for i in range(ndims)]
bdims = ["&b_dim" + str(i) for i in range(ndims)]
cdims = ["&c_dim" + str(i) for i in range(ndims)]
pdims = []
if func_attrs.get("shape") is not None:
pdims = ["p_dim" + str(i) for i in range(len(func_attrs["shape"]))]
extra_shape_func = extra_shape_template.render(indent=" ")
file_pairs = []
has_d0_flag = has_d0(func_attrs)
has_d1_flag = has_d1(func_attrs)
for op_name, op in op_instance.items():
config = emit_instance(op)
config_name = extract_config_name(config)
instance = INSTANCE_TEMPLATE.render(
name="DeviceGemmInstance", config_name=config_name, config=config
)
problem_args = problem_args_template.render(
indent=" ",
gemm_flag=gemm_flag,
has_d0=has_d0_flag,
has_d1=has_d1_flag,
)
exec_program = EXEC_TEMPLATE.render(
indent=" ",
instance="DeviceGemmInstance",
problem_args=problem_args,
is_profiler=True,
)
extra_header = extra_header_template.render(
gemm_flag=gemm_flag, has_d0=has_d0_flag
)
op_func = SRC_TEMPLATE.render(
instances=instance,
function_name="gemm",
ndims=ndims,
pdims=len(pdims),
has_d0=has_d0_flag,
has_d1=has_d1_flag,
shape_func=op_func_shape,
extra_shape=extra_shape_func,
exec_paths=exec_program,
extra_code=extra_code,
gemm_flag=gemm_flag,
extra_header=extra_header,
)
structs_def = STRUCTS_DEF_TEMPLATE.render()
tensor_decl = tensor_decl_template.render(
gemm_flag=gemm_flag, has_d0=has_d0_flag, has_d1=has_d1_flag
)
func_call = FUNC_CALL_TEMPLATE.render(
func_name="gemm",
in_ptr="(void *) memory_pool->RequestHalfTensorByIdx(0)",
weight_ptr="(void *) memory_pool->RequestHalfTensorByIdx(1)",
out_ptr="(void *) memory_pool->RequestHalfTensorByIdx(2)",
bias_ptr="(void *) memory_pool->RequestHalfTensorByIdx(3)",
d0_ptr="(void *) memory_pool->RequestHalfTensorByIdx(4)"
if has_d0_flag
else "",
d1_ptr="(void *) memory_pool->RequestHalfTensorByIdx(5)"
if has_d1_flag
else "",
adims=adims,
bdims=bdims,
cdims=cdims,
pdims=pdims,
gemm_flag=gemm_flag,
)
code = PROFILER_TEMPLATE.render(
structs_def=structs_def,
op_func=op_func,
args_parse=args_parse,
tensor_decl=tensor_decl,
func_call=func_call,
)
prefix = os.path.join(workdir, "profiler", op_type)
if not os.path.exists(prefix):
os.makedirs(prefix)
src_path = os.path.join(prefix, op_name + ".cpp")
obj_path = os.path.join(prefix, op_name)
if os.path.exists(obj_path):
continue
with open(src_path, "w") as fo:
fo.write(code)
file_pairs.append((src_path, obj_path))
return file_pairs
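# Typical use (a sketch; the argument values are assumptions): the GEMM op
# lowering calls
#   file_pairs = gen_profiler(func_attrs, workdir, dim_info_dict,
#                             args_parse, gemm_flag="bias")
# and then compiles each (src_path, obj_path) pair into a standalone profiler
# executable, one per entry in func_attrs["op_instance"].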
def gen_function(
func_attrs,
exec_cond_template,
dim_info_dict,
gemm_flag,
extra_code="",
ndims=2,
extra_shape_template=EXTRA_SHAPE_TEMPLATE,
problem_args_template=PROBLEM_ARGS_TEMPLATE,
extra_header_template=EXTRA_HEADER_TEMPLATE,
input_addr_calculator="",
output_addr_calculator="",
):
"""Generate function body.
Parameters
----------
func_attrs : Dict
Operation attributes.
exec_cond_template : jinja2.Template
Generates if statement to execute kernel.
dim_info_dict: Dict[str, DimInfo]
Generated from gemm._extract_dims().
    Used to store the mapping from dim_names to input / output tensor dims.
gemm_flag : str
    Flag selecting which GEMM variant (epilogue) to generate. Options are '', 'bias', 'bias_relu', 'bias_add_relu'.
extra_code : str
Extra code for self-defined operators.
ndims : int
Number of dims for each parameter, 2 for gemm, 3 for bmm.
extra_shape_template: jinja2.Template
Shape evaluation template.
extra_header_template: jinja2.Template
Extra header template as we have different headers for gemm and bmm.
input_addr_calculator : str
Used to adjust input address based on input tensor accessors if accessors exist
output_addr_calculator : str
Used to adjust output address based on output tensor accessors if accessors exist
Returns
-------
str
The rendered template of generated function body.
"""
func_name = func_attrs["name"]
exec_path = func_attrs["exec_path"]
op_instance = func_attrs["op_instance"]
inst_def_flag = set()
instances = {}
instance_decl = ""
has_d0_flag = has_d0(func_attrs)
has_d1_flag = has_d1(func_attrs)
for key, value in exec_path.items():
fname = "f" + sha1(key.encode()).hexdigest()
algo = value.algo
if algo not in inst_def_flag:
config = emit_instance(op_instance[algo])
inst_def_flag.add(algo)
else:
config = ""
inst = INSTANCE_TEMPLATE.render(
config=config, name=fname, config_name=extract_config_name(config)
)
instances[key] = inst
instance_decl += inst
extra_shape_func = extra_shape_template.render(indent=" ")
shape_eval_func = gemm_common.gen_shape_eval_code(
indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True
)
exec_paths = ""
for key, _ in instances.items():
fname = "f" + sha1(key.encode()).hexdigest()
problem_args = problem_args_template.render(
indent=" ",
gemm_flag=gemm_flag,
has_d0=has_d0_flag,
has_d1=has_d1_flag,
)
program = EXEC_TEMPLATE.render(
indent=" ",
instance=fname,
problem_args=problem_args,
is_profiler=False,
)
exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program)
exec_paths += exec_inst
extra_header = extra_header_template.render(
gemm_flag=gemm_flag, has_d0=has_d0(func_attrs)
)
pdims = len(func_attrs["shape"]) if func_attrs.get("shape") is not None else 0
return SRC_TEMPLATE.render(
instances=instance_decl,
function_name=func_name,
shape_func=shape_eval_func,
extra_shape=extra_shape_func,
input_addr_calculator=input_addr_calculator,
output_addr_calculator=output_addr_calculator,
exec_paths=exec_paths,
extra_code=extra_code,
extra_header=extra_header,
gemm_flag=gemm_flag,
ndims=ndims,
pdims=pdims,
has_d0=has_d0_flag,
has_d1=has_d1_flag,
)
def gen_function_decl(func_name, gemm_flag, ndims=2, pdims=0, has_d0="", has_d1=""):
"""Generates function declarations.
Parameters
----------
    func_name : str
    Name of the generated function.
gemm_flag : str
    Flag selecting which GEMM variant (epilogue) to generate. Options are '', 'bias', 'bias_relu'.
ndims : int
Number of dims for each parameter, 2 for gemm, 3 for bmm.
Returns
-------
str
    The rendered template of the function declaration.
"""
return FUNC_DECL_TEMPLATE.render(
func_name=func_name,
gemm_flag=gemm_flag,
ndims=ndims,
pdims=pdims,
has_d0=has_d0,
has_d1=has_d1,
)
def gen_function_call(func_attrs, indent=" ", gemm_flag=""):
"""Generates function call.
Parameters
----------
func_attrs : Dict
Stores the operation attributes.
indent : str, optional
Indent for codegen, target dependent e.g. C++, python, etc., by default " ".
gemm_flag : str
    Flag selecting which GEMM variant (epilogue) to generate. Options are '', 'bias', 'bias_relu'.
Returns
-------
str
The rendered template of generated function call.
"""
a = func_attrs["inputs"][0]
b = func_attrs["inputs"][1]
c = func_attrs["outputs"][0]
bias_ptr = ""
if "bias" in gemm_flag:
bias = func_attrs["inputs"][2]
bias_ptr = bias._attrs["name"]
d0_ptr = ""
if has_d0(func_attrs):
d0 = func_attrs["inputs"][3]
d0_ptr = d0._attrs["name"]
d1_ptr = ""
if has_d1(func_attrs):
d1 = func_attrs["inputs"][4]
d1_ptr = d1._attrs["name"]
adims = [
"&" + dim._attrs["name"]
for dim in func_attrs["input_accessors"][0].original_shapes
]
bdims = [
"&" + dim._attrs["name"]
for dim in func_attrs["input_accessors"][1].original_shapes
]
cdims = [
"&" + dim._attrs["name"]
for dim in func_attrs["output_accessors"][0].original_shapes
]
pdims = []
if func_attrs.get("shape") is not None:
pdims = list(func_attrs["shape"])
return FUNC_CALL_TEMPLATE.render(
func_name=func_attrs["name"],
in_ptr=a._attrs["name"],
weight_ptr=b._attrs["name"],
out_ptr=c._attrs["name"],
bias_ptr=bias_ptr,
d0_ptr=d0_ptr,
d1_ptr=d1_ptr,
adims=adims,
bdims=bdims,
cdims=cdims,
pdims=pdims,
indent=indent,
gemm_flag=gemm_flag,
)
def default_fproc_f16(*, op, a_layout, b_layout, c_layout):
"""Filter the input operation by layouts.
Parameters
----------
op: operation
aitemplate operation
a_layout: ck_lib.library.LayoutType
a layout type.
b_layout: ck_lib.library.LayoutType
b layout type.
c_layout: ck_lib.library.LayoutType
c layout type.
Returns
-------
List
List of filtered op (can be empty).
"""
import copy
import ck_lib
ret = []
data_type = ck_lib.library.DataType.f16
acc_type = ck_lib.library.DataType.f32
if (
op.A.element == data_type
and op.B.element == data_type
and op.C.element == data_type
and op.accumulator_type() == acc_type
and op.A.layout == a_layout
and op.B.layout == b_layout
and op.C.layout == c_layout
):
ret += [copy.deepcopy(op)]
return ret
def make_fproc_f16(func_attrs, layout, op_kind, extra_kind):
"""This function sets a callback for processing the epilogue of the kernel
associated with func_attrs.
Parameters
----------
func_attrs: Dictionary
kernel attributes dictionary
layout: layout object
kernel layout
op_kind : ck_lib.library.OperationKind
Operation kind.
extra_kind : ck_lib.library.[AnyKind]
    Used as an extra flag to distinguish kernels.
E.g. bias_add_relu vs. add_relu_bias
"""
def fproc_f16(op):
a_layout, b_layout, c_layout = layout.ck_lib_layouts()
return default_fproc_f16(
op=op,
a_layout=a_layout,
b_layout=b_layout,
c_layout=c_layout,
)
func_attrs["op_instance"] = extract_config(op_kind, extra_kind, fproc_f16)
\ No newline at end of file
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "logging.h"
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1,
256, // block_size
64, // m_per_block
128, // n_per_block
32, // k_per_block
8, // ak1
2, // bk1
32, // m_per_xdl
32, // n_per_xdl
1, // m_xdl_per_wave
2, // n_xdl_per_wave
ck::Sequence<4,64,1>, // thread_cluster_length
ck::Sequence<1,0,2>, // thread_cluster_arrange_order
ck::Sequence<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<8,32,1>, // thread_cluster_length
ck::Sequence<0,2,1>, // thread_cluster_arrange_order
ck::Sequence<0,2,1>, // src_access_order
1, // src_vector_dim
4, // src_scalar_per_vector
2, // dst_scalar_per_vector
0, // add_extra_dim
1, // m_xdl_per_wave
1, // n_xdl_per_wave
ck::Sequence<1,32,1,8>, // m_n_block_wave_per_xdl
8 // scalar_per_vector
>;
using fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0 = gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT;
void gemm_rrr_3(
void * in_ptr,
void * weight_ptr,
void * out_ptr,
int64_t* a_dim0,
int64_t* a_dim1,
int64_t* b_dim0,
int64_t* b_dim1,
int64_t* c_dim0,
int64_t* c_dim1,
hipStream_t stream
) {
ck::index_t M = (*a_dim0);
ck::index_t N = (*b_dim1);
ck::index_t K = (*a_dim1);
int64_t offset_a = 0;
int64_t offset_b = 0;
int64_t offset_c = 0;
ck::index_t stride_a = *a_dim1;
ck::index_t stride_b = *b_dim1;
ck::index_t stride_c = *c_dim1;
if (M == 256 && N == 32 && K == 128) {
auto op = fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(
static_cast<ck::half_t *>(in_ptr) + offset_a,
static_cast<ck::half_t *>(weight_ptr) + offset_b,
static_cast<ck::half_t *>(out_ptr) + offset_c,
M,
N,
K,
stride_a,
stride_b,
stride_c,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}
);
if(!op.IsSupportedArgument(argument)) {
LOG(FATAL) << "wrong! " << op.GetTypeString() << " with the specified compilation parameters does not support this Gemm problem.";
}
invoker.Run(argument, StreamConfig{stream, false});
return;
}
LOG(FATAL) << "Unsupported workload for this gemm specialization.";
}
\ No newline at end of file
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "logging.h"
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1,
256, // block_size
64, // m_per_block
128, // n_per_block
32, // k_per_block
8, // ak1
2, // bk1
32, // m_per_xdl
32, // n_per_xdl
1, // m_xdl_per_wave
2, // n_xdl_per_wave
ck::Sequence<4,64,1>, // thread_cluster_length
ck::Sequence<1,0,2>, // thread_cluster_arrange_order
ck::Sequence<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<8,32,1>, // thread_cluster_length
ck::Sequence<0,2,1>, // thread_cluster_arrange_order
ck::Sequence<0,2,1>, // src_access_order
1, // src_vector_dim
4, // src_scalar_per_vector
2, // dst_scalar_per_vector
0, // add_extra_dim
1, // c_shuffle_m_xdl_per_wave_per_shuffle
1, // c_shuffle_n_xdl_per_wave_per_shuffle
ck::Sequence<1,32,1,8>, // m_n_block_wave_per_xdl
8 // scalar_per_vector
>;
using fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0 = gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT;
void gemm_rrr_3(
void * in_ptr,
void * weight_ptr,
void * out_ptr,
int64_t* a_dim0,
int64_t* a_dim1,
int64_t* b_dim0,
int64_t* b_dim1,
int64_t* c_dim0,
int64_t* c_dim1,
hipStream_t stream
) {
ck::index_t M = (*a_dim0);
ck::index_t N = (*b_dim1);
ck::index_t K = (*a_dim1);
int64_t offset_a = 0;
int64_t offset_b = 0;
int64_t offset_c = 0;
ck::index_t stride_a = *a_dim1;
ck::index_t stride_b = *b_dim1;
ck::index_t stride_c = *c_dim1;
if (M == 256 && N == 32 && K == 128) {
auto op = fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(
static_cast<ck::half_t *>(in_ptr) + offset_a,
static_cast<ck::half_t *>(weight_ptr) + offset_b,
static_cast<ck::half_t *>(out_ptr) + offset_c,
M,
N,
K,
stride_a,
stride_b,
stride_c,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}
);
if(!op.IsSupportedArgument(argument)) {
LOG(FATAL) << "wrong! " << op.GetTypeString() << " with the specified compilation parameters does not support this Gemm problem.";
}
invoker.Run(argument, StreamConfig{stream, false});
return;
}
LOG(FATAL) << "Unsupported workload for this gemm specialization.";
}
size_t GLOBAL_WORKSPACE_SIZE = 0;
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/reduction_operator.hpp"
#include "include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
using layernorm_rank2_256_1_256_1_8_1_8_1_8_1_8_8 = ck::tensor_operation::device::DeviceLayernormImpl<
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
2,
1,
256, // block_size
1, // m_cluster_size
256, // k_cluster_size
1, // m_slice_size
8, // k_slice_size
1, // in_src_dim
8, // in_src_size
1, // gamma_src_dim
8, // gamma_src_size
1, // beta_src_dim
8, // beta_src_size
8 // out_dst_size
>;
using DeviceInstance = layernorm_rank2_256_1_256_1_8_1_8_1_8_1_8_8;
void layernorm_0(void* input,
void* gamma,
void* beta,
void* output,
int64_t* in_0,
int64_t* in_1,
hipStream_t stream)
{
int M = *in_0;
int N = *in_1;
std::vector<ck::index_t> i_inStrides;
i_inStrides.push_back(N);
i_inStrides.push_back(1);
auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer(
{M, N},
i_inStrides,
std::vector<ck::index_t>{0, 1},
std::vector<ck::index_t>{0, 1},
i_inStrides,
{1},
1e-05,
static_cast<ck::half_t *>(input),
static_cast<ck::half_t *>(gamma),
static_cast<ck::half_t *>(beta),
static_cast<ck::half_t *>(output),
ck::tensor_operation::element_wise::PassThrough{}
);
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
throw std::runtime_error(
"wrong! device_layernorm with the specified compilation parameters does "
"not support this Softmax problem");
};
std::string instance_name = device_instance.GetTypeString();
auto invoker_ptr = device_instance.MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false});
return;
}
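// Hypothetical usage sketch (an assumption, not generated code): layernorm_0
// normalizes each of the M rows of a row-major [M, N] fp16 tensor, which is
// why the input/output strides above are {N, 1} while gamma/beta use strides
// {0, 1} (length-N vectors broadcast over the rows). The shape below is
// illustrative only.
void example_call_layernorm_0(hipStream_t stream) {
int64_t M = 256;
int64_t N = 1024;
void* in = nullptr;
void* gamma = nullptr;
void* beta = nullptr;
void* out = nullptr;
hipMalloc(&in, M * N * sizeof(ck::half_t));
hipMalloc(&out, M * N * sizeof(ck::half_t));
hipMalloc(&gamma, N * sizeof(ck::half_t));
hipMalloc(&beta, N * sizeof(ck::half_t));
layernorm_0(in, gamma, beta, out, &M, &N, stream);
hipStreamSynchronize(stream);
hipFree(in);
hipFree(gamma);
hipFree(beta);
hipFree(out);
}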
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
ptrs.reserve(512);
// Create the rocRAND generator used by AllocateGaussianTensor before any
// rocrand_set_seed / rocrand_generate_normal call touches it.
rocrand_create_generator(&generator, ROCRAND_RNG_PSEUDO_DEFAULT);
}
~ProfilerMemoryPool() {
rocrand_destroy_generator(generator);
for (size_t i = 0; i < ptrs.size(); i++) {
hipFree(ptrs[i]);
}
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
// rocrand_generate_normal fills the buffer with `size` floats, so the
// allocation must hold at least size * sizeof(float) bytes even when
// DType is narrower (e.g. ck::half_t).
size_t length = size * std::max(sizeof(DType), sizeof(float));
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
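// Hypothetical usage sketch (an assumption): the pool hands out `copy`
// rotating views of a single allocation, so back-to-back profiler iterations
// touch different memory instead of re-reading a cache-hot buffer.
void example_pool_usage() {
ProfilerMemoryPool pool;
int idx = pool.AllocateHalfTensor(/*size=*/256 * 1024, /*copy=*/4);
for (int iter = 0; iter < 8; ++iter) {
// Cycles through the 4 copies: offsets 0, size, 2*size, 3*size, then wraps.
ck::half_t* ptr = pool.RequestHalfTensorByIdx(idx);
(void)ptr; // a kernel would be launched on ptr here
}
}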
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
int main(int argc, char** argv) {
const int64_t in_0 = std::stoi(argv[1]);
const int64_t in_1 = std::stoi(argv[2]);
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
int64_t ptr_sz = in_0 * in_1;
int64_t norm_dim = in_1;
// TODO: special pool size for 8M L2 cache
// need to tune it for other devices
int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 23) / ptr_sz)));
memory_pool->AllocateHalfTensor(ptr_sz, mem_pool_sz); // in: index 0
memory_pool->AllocateHalfTensor(ptr_sz, mem_pool_sz); // out: index 1
memory_pool->AllocateHalfTensor(norm_dim, mem_pool_sz); // gamma: index 2
memory_pool->AllocateHalfTensor(norm_dim, mem_pool_sz); // beta: index 3
// warmup
for(int i = 0; i < 3; ++i) {
layernorm_0(
(void *) memory_pool->RequestHalfTensorByIdx(0),
(void *) memory_pool->RequestHalfTensorByIdx(2),
(void *) memory_pool->RequestHalfTensorByIdx(3),
(void *) memory_pool->RequestHalfTensorByIdx(1),
const_cast<int64_t *>(&in_0),
const_cast<int64_t *>(&in_1),
stream
);
}
// run
KernelTimerImpl timer;
timer.Start();
for(int i = 0; i < 5; ++i) {
layernorm_0(
(void *) memory_pool->RequestHalfTensorByIdx(0),
(void *) memory_pool->RequestHalfTensorByIdx(2),
(void *) memory_pool->RequestHalfTensorByIdx(3),
(void *) memory_pool->RequestHalfTensorByIdx(1),
const_cast<int64_t *>(&in_0),
const_cast<int64_t *>(&in_1),
stream
);
}
timer.End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
}
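// Hypothetical invocation sketch (an assumption about how the build names the
// binary): the profiler takes the two layernorm dimensions on the command
// line and prints the workspace size and the elapsed time of the 5 timed
// iterations, e.g.
//
//   ./layernorm_profiler 256 1024
//   WS:0
//   TIME:0.42
//
// A caller dividing TIME by the iteration count gets an ms/iter estimate.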
#pragma once
#include "logging.h"
#include "device_functions-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include "macros.h"
#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <math.h>
void gemm_rrr_3(
void *,
void *,
void *,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
hipStream_t
);
#define CHECK_VECTOR_ACCESS(vector, idx) \
if (idx >= vector.size()) { \
throw std::out_of_range( \
"[__func__]: index out of range, " #vector ".size()=" + \
std::to_string(vector.size()) + ", got " + std::to_string(idx)); \
}
namespace ait {
namespace {
void DeviceCheckLastError(const char* file, int line) {
auto device_error = GetLastError();
if (device_error != GetDeviceSuccess()) {
std::string msg = std::string("Got error: ") + GetLastErrorString() +
" enum: " + std::to_string(device_error) +
" at " + file + ": " + std::to_string(line);
LOG(ERROR) << msg;
throw std::runtime_error(msg);
}
}
thread_local bool target_has_graph_mode = false;
} // namespace
// Model is the class that actually performs inference. It owns memory for
// intermediate tensors and dynamic dimensions. Constants are owned by
// the model's owning container object, and input/output memory is owned
// by the user.
// Once an inference run has started, it is not safe to re-use the Model
// until the run has finished!
class Model {
public:
Model(
size_t blob_size,
size_t workspace_size,
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
uint8_t* constants,
AITemplateAllocator& allocator)
: blob_(RAII_DeviceMalloc(blob_size, allocator)),
workspace_(RAII_DeviceMalloc(workspace_size, allocator)),
params_(num_inputs + num_outputs + num_unbound_constants),
num_inputs_(num_inputs),
num_outputs_(num_outputs),
constants_(constants) {
dmlc::InitLogging("aitemplate"); // TODO(xxx): render network name
LOG(INFO) << "Init AITemplate Runtime.";
global_workspace_ = static_cast<uint8_t*>(workspace_.get()) + 0;
unique_workspace_ = static_cast<uint8_t*>(workspace_.get());
DEVICE_CHECK(GetDevice(&device_idx_))
DEVICE_CHECK(CreateEvent(&run_finished_));
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
DEVICE_CHECK(cudaDeviceGetAttribute(
&max_smem_size_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx_));
#endif
DEVICE_CHECK(GetDeviceProperties(&device_properties_, device_idx_));
DEVICE_CHECK(StreamCreate(&graph_capture_stream_, /*non_blocking=*/true));
InitConstants(constants_);
auto* blob_ptr = static_cast<uint8_t*>(blob_.get());
}
~Model() {
if (run_finished_ != nullptr) {
DestroyEvent(run_finished_);
}
if (graph_capture_stream_ != nullptr) {
StreamDestroy(graph_capture_stream_);
}
if (graph_exec_ != nullptr) {
GraphExecDestroy(graph_exec_);
}
}
Model(Model&& other) {
run_finished_ = other.run_finished_;
graph_exec_ = other.graph_exec_;
graph_capture_stream_ = other.graph_capture_stream_;
other.run_finished_ = nullptr;
other.graph_exec_ = nullptr;
other.graph_capture_stream_ = nullptr;
constants_ = other.constants_;
num_inputs_ = other.num_inputs_;
global_workspace_ = other.global_workspace_;
unique_workspace_ = other.unique_workspace_;
workspace_ = std::move(other.workspace_);
params_ = std::move(other.params_);
constant_name_to_ptr_ = std::move(other.constant_name_to_ptr_);
// Re-wire the pointers in the above 2 structures.
InitConstants(constants_);
}
Model& operator=(Model&&) = delete;
Model(const Model&) = delete;
Model& operator=(const Model&) = delete;
void SetUpInputsOutputs() {
input_0 = static_cast<decltype(input_0)>(params_[0].ptr);
if (input_0 == nullptr) {
throw std::runtime_error("Constant input_0 was not set! Set the value with set_constant.");
}
input_1 = static_cast<decltype(input_1)>(params_[1].ptr);
if (input_1 == nullptr) {
throw std::runtime_error("Constant input_1 was not set! Set the value with set_constant.");
}
output_0 = static_cast<decltype(output_0)>(params_[2].ptr);
if (output_0 == nullptr) {
throw std::runtime_error("Constant output_0 was not set! Set the value with set_constant.");
}
}
void DeviceToDeviceCopies(StreamType stream) {
}
void Run(StreamType stream, bool graph_mode) {
SetUpInputsOutputs();
if (target_has_graph_mode && graph_mode) {
RunAsGraph(stream);
} else {
RunImpl(stream);
}
DEVICE_CHECK(EventRecord(run_finished_, stream));
}
void RunImpl(StreamType stream) {
gemm_rrr_3(
input_0,
input_1,
output_0,
&input_0_dim_0,
&input_0_dim_1,
&input_1_dim_0,
&input_1_dim_1,
&input_0_dim_0,
&input_1_dim_1,
stream
);
DeviceCheckLastError(__FILE__, __LINE__);
DeviceToDeviceCopies(stream);
}
bool IsPending() {
auto query = QueryEvent(run_finished_);
if (query == GetDeviceNotReady()) {
return true;
}
if (query != GetDeviceSuccess()) {
LOG(WARNING) << "Pending model run did not finish successfully. Error: "
<< GetErrorString(query);
}
return false;
}
void WaitForCompletion() {
DEVICE_CHECK(EventSynchronize(run_finished_));
}
size_t NumInputs() const {
return num_inputs_;
}
size_t NumOutputs() const {
return num_outputs_;
}
void SetParam(const void* src, size_t param_idx) {
CHECK_VECTOR_ACCESS(params_, param_idx)
// const_cast is not ideal here, but it is unfortunately
// necessary:
// 1) We store outputs and inputs in the same vector,
// and outputs cannot be const.
// 2) Most of the codegen is not const-correct (most ops
// require non-const pointers). So even if we put const
// pointers into params, a const_cast would be required
// somewhere else.
params_[param_idx].ptr = const_cast<void*>(src);
}
void SetInput(const void* src, const AITemplateParamShape& shape, size_t idx) {
SetInputShape(shape, idx);
SetParam(src, idx);
}
void SetOutput(void* src, size_t idx) {
SetParam(src, idx + num_inputs_);
}
// Write the (possibly dynamic) output shape to the given pointer.
// Note that this should be called _after_ the shape inference in
// Run() is finished. output_shape_out should be able to store
// at least GetOutputMaximumShape(idx).size values.
void GetOutputShape(size_t idx, int64_t* output_shape_out) {
const auto param_idx = idx + num_inputs_;
CHECK_VECTOR_ACCESS(params_, param_idx);
const auto& shape_ptrs = params_[param_idx].shape_ptrs;
for (size_t i = 0; i < shape_ptrs.size(); ++i) {
output_shape_out[i] = shape_ptrs[i].GetValue();
}
}
void SetConstant(const char* name, const void* src) {
auto it = constant_name_to_ptr_.find(name);
if (it == constant_name_to_ptr_.end()) {
throw std::out_of_range(std::string("Could not find constant ") + name);
}
const void** ptr = it->second;
*ptr = src;
}
private:
void InitConstants(uint8_t* constants) {
params_[0].shape_ptrs = {ParamDim(256, 256, &input_0_dim_0), ParamDim(128, 128, &input_0_dim_1)};
params_[1].shape_ptrs = {ParamDim(128, 128, &input_1_dim_0), ParamDim(32, 32, &input_1_dim_1)};
params_[2].shape_ptrs = {ParamDim(256, 256, &input_0_dim_0), ParamDim(32, 32, &input_1_dim_1)};
}
void SetInputShape(const AITemplateParamShape& shape, size_t idx) {
auto& param = params_[idx];
if (shape.size != param.shape_ptrs.size()) {
throw std::runtime_error(
"[SetInputShape] Got wrong param shape for input " + std::to_string(idx) +
"; expected " + std::to_string(param.shape_ptrs.size()) + ", got " +
std::to_string(shape.size));
}
for (size_t i = 0; i < param.shape_ptrs.size(); ++i) {
param.shape_ptrs[i].SetValue(shape.shape_data[i]);
}
}
DeviceError EndCapture(GraphType* graph_ptr) {
auto err = StreamEndCapture(graph_capture_stream_, graph_ptr);
if (err != GetDeviceSuccess()) {
// If we can't take the stream out of capture mode, something is probably
// wrong with CUDA graph for this model (e.g. there might have been an
// illegal capture mode operation). Disable graph mode to avoid such issues
// in future iterations.
target_has_graph_mode = false;
LOG(WARNING) << "Graph capture failed to end. Disabling graph mode.";
return err;
}
return GetDeviceSuccess();
}
void RunAsGraph(StreamType stream) {
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
RunImpl(graph_capture_stream_);
} catch (...) {
GraphType graph;
// No need to DEVICE_CHECK here, we want to see the original exception.
EndCapture(&graph);
if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
LOG(WARNING) << "Graph destruction failed while handling exception! Memory will be leaked.";
}
throw;
}
// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans it up when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr){ return EndCapture(graph_ptr); }
);
if (graph_exec_ == nullptr) {
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
} else if (GraphExecUpdate(graph_exec_, graph.get()) != GetDeviceSuccess()) {
// Consume the last cuda error, which may affect the next GraphExecLaunch
// call.
GetLastError();
DEVICE_CHECK(GraphExecDestroy(graph_exec_));
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
}
int device_idx_;
int max_smem_size_{0};
DevicePropertyType device_properties_;
// This event tracks when the inference is finished
// so that this Model may be reclaimed by its owning
// ModelContainer.
EventType run_finished_;
// A blob of memory used for storing intermediate tensors.
GPUPtr blob_;
// Memory for constants that were folded into the *.so. Unowned by Model,
// owned by ModelContainer.
// TODO: make this const. It can't be const right now because we derive
// tensor pointers from it, and no tensor pointers are const.
uint8_t* constants_;
size_t num_inputs_;
size_t num_outputs_;
// The workspace blob is used as scratch memory. See
// _generate_workspace in memory planning for more information.
GPUPtr workspace_;
uint8_t* global_workspace_{nullptr};
uint8_t* unique_workspace_{nullptr};
class ParamDim {
public:
ParamDim(int64_t lower_bound, int64_t upper_bound, int64_t* value) :
lower_bound_(lower_bound),
upper_bound_(upper_bound),
value_(value) {}
void SetValue(int64_t new_value) {
if (new_value < lower_bound_ || new_value > upper_bound_) {
throw std::out_of_range(
"[SetValue] Dimension got value out of bounds; expected value to be in [" +
std::to_string(lower_bound_) + ", " + std::to_string(upper_bound_) + "], but got " +
std::to_string(new_value)
);
}
*value_ = new_value;
}
int64_t GetValue() const {
return *value_;
}
private:
int64_t lower_bound_;
int64_t upper_bound_;
int64_t* value_;
};
struct ParamInfo {
void* ptr = nullptr;
// TODO add offset
const char* name;
std::vector<ParamDim> shape_ptrs;
};
// Contains info for all tensors marked as inputs
// or outputs. The first num_inputs elements are the inputs.
// Constants are not included.
std::vector<ParamInfo> params_;
GraphExecType graph_exec_ = nullptr;
StreamType graph_capture_stream_;
std::unordered_map<std::string, const void**> constant_name_to_ptr_;
void * input_0 {nullptr};
void * input_1 {nullptr};
void * output_0 {nullptr};
int64_t input_0_dim_0 { 256 };
int64_t input_0_dim_1 { 128 };
int64_t input_1_dim_0 { 128 };
int64_t input_1_dim_1 { 32 };
};
} // namespace ait
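// Hypothetical usage sketch (an assumption; applications normally go through
// ModelContainer rather than Model directly): the per-run sequence against
// Model's public interface, shown commented out because this header is
// generated. The shape and stream type names are taken from their use above.
//
// void example_run(ait::Model& model,
//                  void* a, const AITemplateParamShape& a_shape,  // 256x128 fp16
//                  void* b, const AITemplateParamShape& b_shape,  // 128x32 fp16
//                  void* c,                                       // 256x32 fp16
//                  StreamType stream) {
//   model.SetInput(a, a_shape, /*idx=*/0);
//   model.SetInput(b, b_shape, /*idx=*/1);
//   model.SetOutput(c, /*idx=*/0);
//   model.Run(stream, /*graph_mode=*/false);
//   model.WaitForCompletion();
//   int64_t out_shape[2];
//   model.GetOutputShape(/*idx=*/0, out_shape);
// }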
#pragma once
#include "logging.h"
#include "device_functions-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include "macros.h"
#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <math.h>
void gemm_rrr_3(
void *,
void *,
void *,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
hipStream_t
);
#define CHECK_VECTOR_ACCESS(vector, idx) \
if (idx >= vector.size()) { \
throw std::out_of_range( \
"[__func__]: index out of range, " #vector ".size()=" + \
std::to_string(vector.size()) + ", got " + std::to_string(idx)); \
}
namespace ait {
namespace {
void DeviceCheckLastError(const char* file, int line) {
auto device_error = GetLastError();
if (device_error != GetDeviceSuccess()) {
std::string msg = std::string("Got error: ") + GetLastErrorString() +
" enum: " + std::to_string(device_error) +
" at " + file + ": " + std::to_string(line);
LOG(ERROR) << msg;
throw std::runtime_error(msg);
}
}
thread_local bool target_has_graph_mode = false;
} // namespace
// Model is the class that actually performs inference. It owns memory for
// intermediate tensors and dynamic dimensions. Constants are owned by
// the model's owning container object, and input/output memory is owned
// by the user.
// Once an inference run has started, it is not safe to re-use the Model
// until the run has finished!
class Model {
public:
Model(
size_t blob_size,
size_t workspace_size,
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
uint8_t* constants,
AITemplateAllocator& allocator)
: blob_(RAII_DeviceMalloc(blob_size, allocator)),
workspace_(RAII_DeviceMalloc(workspace_size, allocator)),
params_(num_inputs + num_outputs + num_unbound_constants),
num_inputs_(num_inputs),
num_outputs_(num_outputs),
constants_(constants) {
dmlc::InitLogging("aitemplate"); // TODO(xxx): render network name
LOG(INFO) << "Init AITemplate Runtime.";
global_workspace_ = static_cast<uint8_t*>(workspace_.get()) + 0;
unique_workspace_ = static_cast<uint8_t*>(workspace_.get());
DEVICE_CHECK(GetDevice(&device_idx_))
DEVICE_CHECK(CreateEvent(&run_finished_));
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
DEVICE_CHECK(cudaDeviceGetAttribute(
&max_smem_size_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx_));
#endif
DEVICE_CHECK(GetDeviceProperties(&device_properties_, device_idx_));
DEVICE_CHECK(StreamCreate(&graph_capture_stream_, /*non_blocking=*/true));
InitConstants(constants_);
auto* blob_ptr = static_cast<uint8_t*>(blob_.get());
}
~Model() {
if (run_finished_ != nullptr) {
DestroyEvent(run_finished_);
}
if (graph_capture_stream_ != nullptr) {
StreamDestroy(graph_capture_stream_);
}
if (graph_exec_ != nullptr) {
GraphExecDestroy(graph_exec_);
}
}
Model(Model&& other) {
run_finished_ = other.run_finished_;
graph_exec_ = other.graph_exec_;
graph_capture_stream_ = other.graph_capture_stream_;
other.run_finished_ = nullptr;
other.graph_exec_ = nullptr;
other.graph_capture_stream_ = nullptr;
constants_ = other.constants_;
num_inputs_ = other.num_inputs_;
global_workspace_ = other.global_workspace_;
unique_workspace_ = other.unique_workspace_;
workspace_ = std::move(other.workspace_);
params_ = std::move(other.params_);
constant_name_to_ptr_ = std::move(other.constant_name_to_ptr_);
// Re-wire the pointers in the above 2 structures.
InitConstants(constants_);
}
Model& operator=(Model&&) = delete;
Model(const Model&) = delete;
Model& operator=(const Model&) = delete;
void SetUpInputsOutputs() {
input_0 = static_cast<decltype(input_0)>(params_[0].ptr);
if (input_0 == nullptr) {
throw std::runtime_error("Constant input_0 was not set! Set the value with set_constant.");
}
input_1 = static_cast<decltype(input_1)>(params_[1].ptr);
if (input_1 == nullptr) {
throw std::runtime_error("Constant input_1 was not set! Set the value with set_constant.");
}
output_0 = static_cast<decltype(output_0)>(params_[2].ptr);
if (output_0 == nullptr) {
throw std::runtime_error("Constant output_0 was not set! Set the value with set_constant.");
}
}
void DeviceToDeviceCopies(StreamType stream) {
}
void Run(StreamType stream, bool graph_mode) {
SetUpInputsOutputs();
if (target_has_graph_mode && graph_mode) {
RunAsGraph(stream);
} else {
RunImpl(stream);
}
DEVICE_CHECK(EventRecord(run_finished_, stream));
}
void RunImpl(StreamType stream) {
gemm_rrr_3(
input_0,
input_1,
output_0,
&input_0_dim_0,
&input_0_dim_1,
&input_1_dim_0,
&input_1_dim_1,
&input_0_dim_0,
&input_1_dim_1,
stream
);
DeviceCheckLastError(__FILE__, __LINE__);
DeviceToDeviceCopies(stream);
}
bool IsPending() {
auto query = QueryEvent(run_finished_);
if (query == GetDeviceNotReady()) {
return true;
}
if (query != GetDeviceSuccess()) {
LOG(WARNING) << "Pending model run did not finish successfully. Error: "
<< GetErrorString(query);
}
return false;
}
void WaitForCompletion() {
DEVICE_CHECK(EventSynchronize(run_finished_));
}
size_t NumInputs() const {
return num_inputs_;
}
size_t NumOutputs() const {
return num_outputs_;
}
void SetParam(const void* src, size_t param_idx) {
CHECK_VECTOR_ACCESS(params_, param_idx)
// const_cast is not ideal here, but it is unfortunately
// necessary:
// 1) We store outputs and inputs in the same vector,
// and outputs cannot be const.
// 2) Most of the codegen is not const-correct (most ops
// require non-const pointers). So even if we put const
// pointers into params, a const_cast would be required
// somewhere else.
params_[param_idx].ptr = const_cast<void*>(src);
}
void SetInput(const void* src, const AITemplateParamShape& shape, size_t idx) {
SetInputShape(shape, idx);
SetParam(src, idx);
}
void SetOutput(void* src, size_t idx) {
SetParam(src, idx + num_inputs_);
}
// Write the (possibly dynamic) output shape to the given pointer.
// Note that this should be called _after_ the shape inference in
// Run() is finished. output_shape_out should be able to store
// at least GetOutputMaximumShape(idx).size values.
void GetOutputShape(size_t idx, int64_t* output_shape_out) {
const auto param_idx = idx + num_inputs_;
CHECK_VECTOR_ACCESS(params_, param_idx);
const auto& shape_ptrs = params_[param_idx].shape_ptrs;
for (size_t i = 0; i < shape_ptrs.size(); ++i) {
output_shape_out[i] = shape_ptrs[i].GetValue();
}
}
void SetConstant(const char* name, const void* src) {
auto it = constant_name_to_ptr_.find(name);
if (it == constant_name_to_ptr_.end()) {
throw std::out_of_range(std::string("Could not find constant ") + name);
}
const void** ptr = it->second;
*ptr = src;
}
private:
void InitConstants(uint8_t* constants) {
params_[0].shape_ptrs = {ParamDim(256, 256, &input_0_dim_0), ParamDim(128, 128, &input_0_dim_1)};
params_[1].shape_ptrs = {ParamDim(128, 128, &input_1_dim_0), ParamDim(32, 32, &input_1_dim_1)};
params_[2].shape_ptrs = {ParamDim(256, 256, &input_0_dim_0), ParamDim(32, 32, &input_1_dim_1)};
}
void SetInputShape(const AITemplateParamShape& shape, size_t idx) {
auto& param = params_[idx];
if (shape.size != param.shape_ptrs.size()) {
throw std::runtime_error(
"[SetInputShape] Got wrong param shape for input " + std::to_string(idx) +
"; expected " + std::to_string(param.shape_ptrs.size()) + ", got " +
std::to_string(shape.size));
}
for (size_t i = 0; i < param.shape_ptrs.size(); ++i) {
param.shape_ptrs[i].SetValue(shape.shape_data[i]);
}
}
DeviceError EndCapture(GraphType* graph_ptr) {
auto err = StreamEndCapture(graph_capture_stream_, graph_ptr);
if (err != GetDeviceSuccess()) {
// If we can't take the stream out of capture mode, something is probably
// wrong with CUDA graph for this model (e.g. there might have been an
// illegal capture mode operation). Disable graph mode to avoid such issues
// in future iterations.
target_has_graph_mode = false;
LOG(WARNING) << "Graph capture failed to end. Disabling graph mode.";
return err;
}
return GetDeviceSuccess();
}
void RunAsGraph(StreamType stream) {
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
RunImpl(graph_capture_stream_);
} catch (...) {
GraphType graph;
// No need to DEVICE_CHECK here, we want to see the original exception.
EndCapture(&graph);
if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
LOG(WARNING) << "Graph destruction failed while handling exception! Memory will be leaked.";
}
throw;
}
// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans it up when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr){ return EndCapture(graph_ptr); }
);
if (graph_exec_ == nullptr) {
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
} else if (GraphExecUpdate(graph_exec_, graph.get()) != GetDeviceSuccess()) {
// Consume the last cuda error, which may affect the next GraphExecLaunch
// call.
GetLastError();
DEVICE_CHECK(GraphExecDestroy(graph_exec_));
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
}
int device_idx_;
int max_smem_size_{0};
DevicePropertyType device_properties_;
// This event tracks when the inference is finished
// so that this Model may be reclaimed by its owning
// ModelContainer.
EventType run_finished_;
// A blob of memory used for storing intermediate tensors.
GPUPtr blob_;
// Memory for constants that were folded into the *.so. Unowned by Model,
// owned by ModelContainer.
// TODO: make this const. It can't be const right now because we derive
// tensor pointers from it, and no tensor pointers are const.
uint8_t* constants_;
size_t num_inputs_;
size_t num_outputs_;
// The workspace blob is used as scratch memory. See
// _generate_workspace in memory planning for more information.
GPUPtr workspace_;
uint8_t* global_workspace_{nullptr};
uint8_t* unique_workspace_{nullptr};
class ParamDim {
public:
ParamDim(int64_t lower_bound, int64_t upper_bound, int64_t* value) :
lower_bound_(lower_bound),
upper_bound_(upper_bound),
value_(value) {}
void SetValue(int64_t new_value) {
if (new_value < lower_bound_ || new_value > upper_bound_) {
throw std::out_of_range(
"[SetValue] Dimension got value out of bounds; expected value to be in [" +
std::to_string(lower_bound_) + ", " + std::to_string(upper_bound_) + "], but got " +
std::to_string(new_value)
);
}
*value_ = new_value;
}
int64_t GetValue() const {
return *value_;
}
private:
int64_t lower_bound_;
int64_t upper_bound_;
int64_t* value_;
};
struct ParamInfo {
void* ptr = nullptr;
// TODO add offset
const char* name;
std::vector<ParamDim> shape_ptrs;
};
// Contains info for all tensors marked as inputs
// or outputs. The first num_inputs elements are the inputs.
// Constants are not included.
std::vector<ParamInfo> params_;
GraphExecType graph_exec_ = nullptr;
StreamType graph_capture_stream_;
std::unordered_map<std::string, const void**> constant_name_to_ptr_;
void * input_0 {nullptr};
void * input_1 {nullptr};
void * output_0 {nullptr};
int64_t input_0_dim_0 { 256 };
int64_t input_0_dim_1 { 128 };
int64_t input_1_dim_0 { 128 };
int64_t input_1_dim_1 { 32 };
};
} // namespace ait
#include "model_container.h"
#include "device_functions-generated.h"
#include "raii_wrapper.h"
namespace ait {
ModelContainer::ModelContainer(
size_t num_models,
size_t blob_size,
size_t workspace_size,
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
size_t params_size,
AITemplateAllocator& allocator)
: ModelContainerBase(
num_inputs,
num_outputs,
num_unbound_constants,
params_size,
allocator),
allocator_(allocator),
num_inputs_(num_inputs),
num_outputs_(num_outputs) {
if (num_models == 0) {
throw std::runtime_error("Number of models must be positive");
}
models_.reserve(num_models);
available_models_.reserve(num_models);
for (size_t i = 0; i < num_models; ++i) {
models_.emplace_back(
blob_size,
workspace_size,
num_inputs,
num_outputs,
num_unbound_constants,
static_cast<uint8_t*>(constants_.get()),
allocator);
available_models_.push_back(&models_.back());
}
}
void ModelContainer::Run(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool sync,
bool graph_mode,
int64_t** output_shapes_out) {
auto* model = GetAvailableModel();
try {
PrepareForRun(model, inputs, num_inputs, outputs, num_outputs);
model->Run(stream, graph_mode);
} catch (...) {
std::lock_guard lk(models_mutex_);
available_models_.push_back(model);
throw;
}
if (output_shapes_out) {
for (size_t i = 0; i < num_outputs; ++i) {
auto* out_shape = output_shapes_out[i];
model->GetOutputShape(i, out_shape);
}
}
{
std::lock_guard lk(models_mutex_);
pending_models_.push_back(model);
}
pending_models_available_.notify_one();
if (sync) {
StreamSynchronize(stream);
}
}
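// Hypothetical caller sketch (an assumption; the AITData and
// AITemplateParamShape field order below is inferred from their use in
// PrepareForRun and RunWithOutputsOnHost): one synchronous inference through
// the container for the two-input, one-output gemm model in this commit.
//
// void example_container_run(ait::ModelContainer& container,
//                            void* a_dev, void* b_dev, void* c_dev,
//                            StreamType stream) {
//   int64_t a_shape[2] = {256, 128};
//   int64_t b_shape[2] = {128, 32};
//   int64_t c_shape[2] = {0, 0};  // filled in via output_shapes_out
//   AITData inputs[2] = {
//       AITData{a_dev, AITemplateParamShape{a_shape, 2}, AITemplateDtype::kHalf},
//       AITData{b_dev, AITemplateParamShape{b_shape, 2}, AITemplateDtype::kHalf}};
//   AITData output{c_dev, AITemplateParamShape{c_shape, 2}, AITemplateDtype::kHalf};
//   int64_t* output_shapes[1] = {c_shape};
//   container.Run(inputs, 2, &output, 1, stream,
//                 /*sync=*/true, /*graph_mode=*/false, output_shapes);
// }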
void ModelContainer::RunWithOutputsOnHost(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
int64_t** output_shapes_out) {
std::vector<std::pair<GPUPtr, size_t>> owned_outputs_ptrs;
std::vector<AITData> owned_outputs;
owned_outputs_ptrs.reserve(num_outputs);
owned_outputs.reserve(num_outputs);
for (size_t i = 0; i < num_outputs; ++i) {
size_t num_bytes = MaxOutputStorageBytes(i);
owned_outputs_ptrs.emplace_back(
RAII_DeviceMalloc(num_bytes, allocator_), num_bytes);
owned_outputs.emplace_back(
owned_outputs_ptrs.back().first.get(),
outputs[i].shape,
outputs[i].dtype);
}
Run(inputs,
num_inputs,
owned_outputs.data(),
num_outputs,
stream,
/*sync=*/false,
graph_mode,
output_shapes_out);
for (size_t i = 0; i < num_outputs; ++i) {
auto& owned_output = owned_outputs_ptrs[i];
auto& ptr = owned_output.first;
auto num_bytes = owned_output.second;
DEVICE_CHECK(CopyToHost(outputs[i].ptr, ptr.get(), num_bytes, stream));
}
DEVICE_CHECK(StreamSynchronize(stream));
}
float ModelContainer::Benchmark(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
size_t count,
size_t num_threads,
bool use_unique_stream_per_thread,
int64_t** output_shapes_out) {
if (num_threads == 0) {
num_threads = std::thread::hardware_concurrency();
}
if (num_threads == 1) {
return BenchmarkImpl(
inputs,
num_inputs,
outputs,
num_outputs,
stream,
graph_mode,
count,
output_shapes_out) /
count;
}
// Clone the outputs, each thread needs its own set
std::vector<std::vector<GPUPtr>> per_thread_outputs_ptrs;
std::vector<std::vector<AITData>> per_thread_outputs;
std::vector<StreamPtr> per_thread_streams;
per_thread_outputs_ptrs.reserve(num_threads - 1);
per_thread_outputs.reserve(num_threads - 1);
if (use_unique_stream_per_thread) {
per_thread_streams.reserve(num_threads);
for (size_t i = 0; i < num_threads; ++i) {
per_thread_streams.push_back(RAII_StreamCreate(/*non_blocking=*/true));
}
}
for (size_t i = 1; i < num_threads; ++i) {
std::vector<GPUPtr> cloned_outputs_ptrs;
std::vector<AITData> cloned_outputs;
cloned_outputs_ptrs.reserve(num_outputs);
cloned_outputs.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t num_bytes = MaxOutputStorageBytes(j);
cloned_outputs_ptrs.emplace_back(
RAII_DeviceMalloc(num_bytes, allocator_));
auto* new_pointer = cloned_outputs_ptrs.back().get();
DEVICE_CHECK(
DeviceToDeviceCopy(new_pointer, outputs[j].ptr, num_bytes, stream));
cloned_outputs.emplace_back(
new_pointer, outputs[j].shape, outputs[j].dtype);
}
per_thread_outputs_ptrs.push_back(std::move(cloned_outputs_ptrs));
per_thread_outputs.push_back(std::move(cloned_outputs));
}
DEVICE_CHECK(StreamSynchronize(stream));
auto get_stream = [stream, use_unique_stream_per_thread, &per_thread_streams](
size_t thread_idx) {
if (!use_unique_stream_per_thread) {
return stream;
}
return per_thread_streams[thread_idx].get();
};
auto thread_func = [&](size_t thread_idx) {
AITData* thread_outputs =
thread_idx == 0 ? outputs : per_thread_outputs[thread_idx - 1].data();
StreamType thread_stream = get_stream(thread_idx);
auto* thread_output_shapes_out =
thread_idx == 0 ? output_shapes_out : nullptr;
return BenchmarkImpl(
inputs,
num_inputs,
thread_outputs,
num_outputs,
thread_stream,
graph_mode,
count,
thread_output_shapes_out);
};
std::vector<std::future<float>> futures;
futures.reserve(num_threads);
for (size_t i = 0; i < num_threads; ++i) {
futures.push_back(std::async(std::launch::async, thread_func, i));
}
auto max_time = std::accumulate(
futures.begin(), futures.end(), 0.f, [](float cur_val, auto& future) {
return std::max(future.get(), cur_val);
});
// Verify that all the outputs are the same
for (size_t i = 0; i < num_outputs; ++i) {
auto output_size = MaxOutputStorageBytes(i);
auto output_host = std::make_unique<uint8_t[]>(output_size);
// NB: technically, we don't have to copy to host here, but using
// std::memcmp is easier than writing a kernel that does comparisons
// for both backends, and performance is not important here.
DEVICE_CHECK(
CopyToHost(output_host.get(), outputs[i].ptr, output_size, stream));
DEVICE_CHECK(StreamSynchronize(stream));
for (size_t thread_idx = 1; thread_idx < num_threads; ++thread_idx) {
auto* thread_output = per_thread_outputs[thread_idx - 1][i].ptr;
auto thread_output_host = std::make_unique<uint8_t[]>(output_size);
auto thread_stream = get_stream(thread_idx);
DEVICE_CHECK(CopyToHost(
thread_output_host.get(), thread_output, output_size, thread_stream));
DEVICE_CHECK(StreamSynchronize(thread_stream));
if (std::memcmp(
output_host.get(), thread_output_host.get(), output_size)) {
throw std::runtime_error(
"Output " + std::to_string(i) +
" did not match for a spawned thread!");
}
}
}
auto total_num_iters = num_threads * count;
return max_time / total_num_iters;
}
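// Worked example (illustrative assumption): with num_threads = 4 and
// count = 10, each thread times its own 10 iterations and max_time is the
// slowest thread's total, so Benchmark returns max_time / (4 * 10); for
// max_time = 80 ms that is 2 ms per iteration.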
void ModelContainer::SetConstant(const char* name, const AITData& tensor) {
auto it = unbound_constant_name_to_idx_.find(name);
if (it == unbound_constant_name_to_idx_.end()) {
// TODO make this an exception after we fix the CMF benchmarks
LOG(ERROR) << "Constant " << name << " not found";
return;
}
auto constant_idx = it->second + num_inputs_ + num_outputs_;
ValidateDtype(tensor.dtype, constant_idx);
CHECK_VECTOR_ACCESS(max_param_storage_bytes_, constant_idx)
auto expected_num_bytes = max_param_storage_bytes_[constant_idx];
auto actual_num_bytes =
tensor.shape.Numel() * AITemplateDtypeSizeBytes(tensor.dtype);
if (expected_num_bytes != actual_num_bytes) {
throw std::runtime_error(
std::string(
"SetConstant did not recieve correct number of bytes for constant ") +
name + ": expected " + std::to_string(expected_num_bytes) +
" but got " + std::to_string(actual_num_bytes) +
". Check that the provided tensor's shape is correct.");
}
auto* src = tensor.ptr;
for (auto& model : models_) {
model.SetConstant(name, src);
}
}
size_t ModelContainer::NumInputs() const {
return num_inputs_;
}
const char* ModelContainer::InputName(size_t input_idx) const {
CHECK_VECTOR_ACCESS(param_names_, input_idx)
return param_names_[input_idx];
}
size_t ModelContainer::NumOutputs() const {
return num_outputs_;
}
const char* ModelContainer::OutputName(size_t output_idx) const {
auto idx = output_idx + num_inputs_;
CHECK_VECTOR_ACCESS(param_names_, idx)
return param_names_[idx];
}
AITemplateParamShape ModelContainer::MaxOutputShape(size_t output_idx) const {
auto idx = output_idx + num_inputs_;
CHECK_VECTOR_ACCESS(max_param_shapes_, idx)
auto& out_shape = max_param_shapes_[idx];
return AITemplateParamShape{out_shape.data(), out_shape.size()};
}
AITemplateDtype ModelContainer::OutputDtype(size_t output_idx) const {
auto idx = output_idx + num_inputs_;
CHECK_VECTOR_ACCESS(param_dtypes_, idx)
return param_dtypes_[idx];
}
size_t ModelContainer::MaxOutputStorageBytes(size_t output_idx) const {
auto idx = output_idx + num_inputs_;
CHECK_VECTOR_ACCESS(max_param_storage_bytes_, idx)
return max_param_storage_bytes_[idx];
}
void ModelContainer::PrepareForRun(
Model* model,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs) {
if (num_inputs != num_inputs_) {
auto msg = "Got wrong number of inputs; expected " +
std::to_string(num_inputs_) + ", got " + std::to_string(num_inputs);
throw std::runtime_error(std::move(msg));
}
if (num_inputs > 0 && inputs == nullptr) {
throw std::runtime_error("inputs cannot be null");
}
if (num_outputs != num_outputs_) {
auto msg = "Got wrong number of outputs; expected " +
std::to_string(num_outputs_) + ", got " + std::to_string(num_outputs);
throw std::runtime_error(std::move(msg));
}
if (num_outputs > 0 && outputs == nullptr) {
throw std::runtime_error("outputs cannot be null");
}
for (size_t i = 0; i < num_inputs_; ++i) {
auto& input = inputs[i];
ValidateDtype(input.dtype, i);
model->SetInput(input.ptr, input.shape, i);
}
for (size_t i = 0; i < num_outputs_; ++i) {
auto& output = outputs[i];
ValidateDtype(output.dtype, i + num_inputs_);
model->SetOutput(output.ptr, i);
}
}
Model* ModelContainer::GetAvailableModel() {
std::unique_lock lk(models_mutex_);
if (available_models_.empty()) {
ReclaimFinishedModels(lk);
}
auto* result = available_models_.back();
available_models_.pop_back();
return result;
}
void ModelContainer::ReclaimFinishedModels(std::unique_lock<std::mutex>& lk) {
// Put any complete models at the end
auto it = std::stable_partition(
pending_models_.begin(), pending_models_.end(), [](Model* m) {
return m->IsPending();
});
if (it != pending_models_.end()) {
// Move all available models to the pool.
available_models_.insert(
available_models_.end(), it, pending_models_.end());
pending_models_.erase(it, pending_models_.end());
return;
}
pending_models_available_.wait(
lk, [this]() { return !pending_models_.empty(); });
// There are no available workspaces! We have to wait on one.
auto* model = pending_models_.front();
pending_models_.pop_front();
lk.unlock();
try {
model->WaitForCompletion();
} catch (...) {
lk.lock();
available_models_.push_back(model);
throw;
}
lk.lock();
available_models_.push_back(model);
}
void ModelContainer::ValidateDtype(AITemplateDtype dtype, size_t idx) const {
CHECK_VECTOR_ACCESS(param_dtypes_, idx)
if (dtype != param_dtypes_[idx]) {
auto GetEnumString = [](auto dtype) {
switch (dtype) {
case AITemplateDtype::kUnset:
return "kUnset";
case AITemplateDtype::kHalf:
return "kHalf";
case AITemplateDtype::kFloat:
return "kFloat";
case AITemplateDtype::kInt:
return "kInt";
case AITemplateDtype::kLong:
return "kLong";
default:
return "unknown";
}
};
throw std::runtime_error(
"Got wrong dtype for param " + std::to_string(idx) + "; expected " +
GetEnumString(param_dtypes_[idx]) + ", got " + GetEnumString(dtype));
}
}
float ModelContainer::BenchmarkImpl(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
size_t count,
int64_t** output_shapes_out) {
auto* model = GetAvailableModel();
float runtime_ms = 0.;
auto start_event = RAII_CreateEvent();
auto end_event = RAII_CreateEvent();
try {
PrepareForRun(model, inputs, num_inputs, outputs, num_outputs);
DEVICE_CHECK(EventRecord(start_event.get(), stream));
for (size_t i = 0; i < count; ++i) {
model->Run(stream, graph_mode);
}
} catch (...) {
std::lock_guard lk(models_mutex_);
available_models_.push_back(model);
throw;
}
if (output_shapes_out) {
for (size_t i = 0; i < num_outputs; ++i) {
auto* out_shape = output_shapes_out[i];
model->GetOutputShape(i, out_shape);
}
}
// Push the model back into the pool before synchronizing the event
// to exercise the concurrency code
{
std::lock_guard lk(models_mutex_);
pending_models_.push_back(model);
}
pending_models_available_.notify_one();
DEVICE_CHECK(EventRecord(end_event.get(), stream));
DEVICE_CHECK(EventSynchronize(end_event.get()));
DEVICE_CHECK(
EventElapsedTime(&runtime_ms, start_event.get(), end_event.get()));
LOG(INFO) << "Benchmark runtime ms/iter: " << runtime_ms / count;
return runtime_ms;
}
} // namespace ait