"vscode:/vscode.git/clone" did not exist on "d6a2cfb5f33ce800f9c57ea79d8ef4169bc86ca3"
Commit 47cc9b7e authored by Astha Rai's avatar Astha Rai
Browse files

added compilation of shared library and multiple instances for gemm, cleaned up code design

parent adbefd90
# CK Python API
This API uses Python to generate instances of the operations present in CK and to compile them, either into an executable that runs a single instance or into a shared library containing multiple instances.
There are two directories: `normal` and `shared`. The `normal` directory generates one instance and compiles it into an executable, while the `shared` directory generates multiple instances and compiles them into a shared library.
## Normal
The normal flow emits a single GEMM instance (`ex.cpp`) together with a Makefile, then runs `make` to build an executable.
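A minimal sketch of that flow (the module name is an assumption; the `EmitGemmInstance` class and `emit()` come from the generator script in this commit):

```python
# Hypothetical driver for the normal flow; `gemm_emit` is an assumed module/file name.
from gemm_emit import EmitGemmInstance

emitter = EmitGemmInstance()
emitter.emit()  # writes ex.cpp and a Makefile, then runs `make` via subprocess
```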
## Shared
The shared flow emits multiple instances (e.g. `xx.cpp`), compiles each with `-fPIC`, and links the resulting objects into a shared library.
gemm: xx.o
CFLAGS=-I ~/workspace/composable_kernel/include -I /opt/workspace/rocm-5.1.1/hip/include -I ~/workspace/composable_kernel/include/ -I ~/workspace/composable_kernel/include/ck/ -I ~/workspace/composable_kernel/include/ck/problem_transform/ -I ~/workspace/composable_kernel/include/ck/tensor/ -I ~/workspace/composable_kernel/include/ck/tensor_description/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/ -I ~/workspace/composable_kernel/include/ck/host_utility -I /external/include/half/ -I ~/workspace/composable_kernel/library/include/ck/library/host/ -I ~/workspace/composable_kernel/library/include/ck/library/host_tensor/ -I ~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_op/ -I ~/workspace/composable_kernel/library/include/ck/library/utility/ -I ~/workspace/composable_kernel/profiler/include/
CXXFLAGS = -std=c++17
xx.o:
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w /opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c xx.cpp
CC = {{cc}}
CFLAGS = {{CFLAGS}}
fPIC_flag = {{fPIC}}
obj_files = {{obj_files}}
%.obj : %.{{cpp}}
{{cfile_cmd}}
%.obj : %.bin
{{bfile_cmd}}
.PHONY: all clean clean_constants
all: {{target}}
{{target}}: $(obj_files)
$(CC) -shared $(fPIC_flag) $(CFLAGS) -o $@ $(obj_files)
clean:
rm -f *.obj {{target}} test.so
clean_constants:
rm -f constants.bin
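# The {{...}} tokens above are placeholders to be filled in by a generator; the rendering
# step itself is not shown in this file. Illustrative values (names mirror the placeholders,
# the concrete values are assumptions only):
#   cc        -> hipcc
#   fPIC      -> -fPIC
#   obj_files -> the generated instance objects, e.g. xx.o
#   target    -> the shared library to link, e.g. test.so
#   cfile_cmd / bfile_cmd -> compile commands for the .cpp / .bin prerequisites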
import enum
import os.path
import shutil
import functools
import operator
import collections
import subprocess
import re
def SubstituteTemplate(template, values):
    # Repeatedly replace ${key} placeholders until the text stops changing.
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text
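# Example (illustrative): SubstituteTemplate("${type_a} a_val;", {"type_a": "ck::half_t"})
# returns "ck::half_t a_val;". Placeholders whose key is missing from `values` are left untouched.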
class EmitGemmInstance:
    def __init__(self):
        self.make_template = """
CFLAGS=-I ~/workspace/composable_kernel/include -I /opt/workspace/rocm-5.1.1/hip/include -I ~/workspace/composable_kernel/include/ -I ~/workspace/composable_kernel/include/ck/ -I ~/workspace/composable_kernel/example/01_gemm/ -I ~/workspace/composable_kernel/library/include/ -I ~/workspace/composable_kernel/library/src/utility/ -I ~/workspace/composable_kernel/include/ck/problem_transform/ -I ~/workspace/composable_kernel/include/ck/tensor/ -I ~/workspace/composable_kernel/include/ck/tensor_description/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/block/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/device/impl/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/element/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/grid/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/thread/ -I ~/workspace/composable_kernel/include/ck/tensor_operation/gpu/warp/ -I ~/workspace/composable_kernel/include/ck/host_utility -I /external/include/half/ -I ~/workspace/composable_kernel/library/include/ck/library/host/ -I ~/workspace/composable_kernel/library/include/ck/library/host_tensor/ -I ~/workspace/composable_kernel/library/include/ck/library/obselete_driver_offline/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/ -I ~/workspace/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/ -I ~/workspace/composable_kernel/library/include/ck/library/tensor_op/ -I ~/workspace/composable_kernel/library/include/ck/library/utility/ -I ~/workspace/composable_kernel/profiler/include/
CXXFLAGS = -std=c++17
gemm: ex.o host_tensor.o device_memory.o
hipcc $(CXXFLAGS) $(CFLAGS) ex.o host_tensor.o device_memory.o -o gemm
device_memory.o: ../../../../library/src/utility/device_memory.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/device_memory.cpp
host_tensor.o: ../../../../library/src/utility/host_tensor.cpp
hipcc $(CXXFLAGS) $(CFLAGS) -c ../../../../library/src/utility/host_tensor.cpp
ex.o:
hipcc -fPIC -fvisibility=hidden $(CXXFLAGS) -w /opt/rocm-5.3.0/amdgcn/bitcode/oclc_abi_version_400.bc $(CFLAGS) -L/opt/rocm-5.3.0/rocrand -lrocrand -x hip -c ex.cpp
"""
        self.gemm_devop_template = """
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
${type_a},
${type_b},
${type_c},
${type_acc},
${layout_a},
${layout_b},
${layout_c},
${elementwise_op_a},
${elementwise_op_b},
${elementwise_op_c},
${Gemm_spec},
${block_size},
${mperblock},
${nperblock},
${k0perblock},
${k1},
${m1perthread},
${n1perthread},
${kperthread},
${m1n1_thcluster_m1xs},
${m1n1_thcluster_n1xs},
${ABT_thread_slice_lengths_K0_M0_M1_K1},
${ABT_thread_cluster_lengths_K0_M0_M1_K1},
${ABT_thread_cluster_arrange_order},
${ABT_src_access_order},
${ABT_src_vec_tensor_lengths_K0_M0_M1_K1},
${ABT_src_vec_tensor_cont_dim_order},
${ABT_dst_vec_tensor_lengths_K0_M0_M1_K1},
${BBT_thread_slice_lengths_K0_N0_N1_K1},
${BBT_thread_cluster_lengths_K0_N0_N1_K1},
${BBT_thread_cluster_arrange_order},
${BBT_src_access_order},
${BBT_src_vec_tensor_lengths_K0_N0_N1_K1},
${BBT_src_vec_tensor_cont_dim_order},
${BBT_dst_vec_tensor_lengths_K0_N0_N1_K1},
${CTT_src_dst_access_order},
${CTT_src_dst_vec_dim},
${CTT_dst_scalar_per_vector}>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<${type_a}> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ${layout_a}{}));
Tensor<${type_b}> b_k_n(f_host_tensor_descriptor(K, N, StrideB, ${layout_b}{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<${type_a}>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<${type_b}>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<${type_a}>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<${type_b}>{-1.f, 1.f}(b_k_n);
}
Tensor<${type_c}> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<${type_c}> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(${type_a}) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(${type_b}) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(${type_c}) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = ${elementwise_op_a}{};
auto b_element_op = ${elementwise_op_b}{};
auto c_element_op = ${elementwise_op_c}{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<${type_a}*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<${type_b}*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<${type_c}*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(${type_a}) * M * K + sizeof(${type_b}) * K * N + sizeof(${type_c}) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool run_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
"""
    def emit(self):
        # Template parameters for one DeviceGemmDl instance; these are substituted
        # into the ${...} placeholders of gemm_devop_template.
        values = {
            'type_a' : 'ck::half_t',
            'type_b' : 'ck::half_t',
            'type_c' : 'ck::half_t',
            'type_acc' : 'float',
            'layout_a' : 'ck::tensor_layout::gemm::ColumnMajor',
            'layout_b' : 'ck::tensor_layout::gemm::RowMajor',
            'layout_c' : 'ck::tensor_layout::gemm::RowMajor',
            'elementwise_op_a' : 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_b' : 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_c' : 'ck::tensor_operation::element_wise::PassThrough',
            'Gemm_spec' : 'ck::tensor_operation::device::GemmSpecialization::Default',
            'block_size' : '256',
            'mperblock' : '128',
            'nperblock' : '128',
            'k0perblock' : '16',
            'k1' : '2',
            'm1perthread' : '4',
            'n1perthread' : '4',
            'kperthread' : '1',
            'm1n1_thcluster_m1xs' : 'S<8, 2>',
            'm1n1_thcluster_n1xs' : 'S<8, 2>',
            'ABT_thread_slice_lengths_K0_M0_M1_K1' : 'S<2, 1, 4, 2>',
            'ABT_thread_cluster_lengths_K0_M0_M1_K1' : 'S<8, 1, 32, 1>',
            'ABT_thread_cluster_arrange_order' : 'S<0, 3, 1, 2>',
            'ABT_src_access_order' : 'S<0, 3, 1, 2>',
            'ABT_src_vec_tensor_lengths_K0_M0_M1_K1' : 'S<1, 1, 4, 1>',
            'ABT_src_vec_tensor_cont_dim_order' : 'S<0, 3, 1, 2>',
            'ABT_dst_vec_tensor_lengths_K0_M0_M1_K1' : 'S<1, 1, 4, 2>',
            'BBT_thread_slice_lengths_K0_N0_N1_K1' : 'S<2, 1, 4, 2>',
            'BBT_thread_cluster_lengths_K0_N0_N1_K1' : 'S<8, 1, 32, 1>',
            'BBT_thread_cluster_arrange_order' : 'S<0, 3, 1, 2>',
            'BBT_src_access_order' : 'S<0, 3, 1, 2>',
            'BBT_src_vec_tensor_lengths_K0_N0_N1_K1' : 'S<1, 1, 4, 1>',
            'BBT_src_vec_tensor_cont_dim_order' : 'S<0, 3, 1, 2>',
            'BBT_dst_vec_tensor_lengths_K0_N0_N1_K1': 'S<1, 1, 4, 2>',
            'CTT_src_dst_access_order' : 'S<0, 1, 2, 3, 4, 5>',
            'CTT_src_dst_vec_dim' : '5',
            'CTT_dst_scalar_per_vector' : '4'
        }
        # Substitute the parameters into the instance template and write ex.cpp.
        template = self.gemm_devop_template
        cf = open("ex.cpp", 'w')
        print(SubstituteTemplate(template, values))
        cf.write(SubstituteTemplate(template, values))
        cf.close()
        # Write the Makefile, then build the example by invoking make.
        m_template = self.make_template
        cf = open("Makefile", 'w')
        print(SubstituteTemplate(m_template, values))
        cf.write(SubstituteTemplate(m_template, values))
        cf.close()
        PIPE = -1
        STDOUT = -2
        proc = subprocess.Popen(
            ["make"],
            shell=True,
            env=os.environ.copy(),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        out, err = proc.communicate()
a = EmitGemmInstance()
a.emit()
import enum
import os.path
import shutil
import functools
import operator
import collections
import re
def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text
class EmitGemmInstance:
    def __init__(self):
        self.gemm_devop_template = """
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
256,
128,
128,
16,
2,
4,
4,
1,
S<8, 2>,
S<8, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
// Host-side reference GEMM used by the verification path in run_gemm below
// (same definition as in the single-instance example).
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool run_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
"""
    def emit(self):
        # NOTE: this template is fully hard-coded, so `values` has no ${...} placeholders to fill yet.
        values = {
            'type_a' : 'ck::half_t',
        }
        template = self.gemm_devop_template
        cf = open("xx.cpp", 'w')
        print(SubstituteTemplate(template, values))
        cf.write(SubstituteTemplate(template, values))
        cf.close()
a = EmitGemmInstance()
a.emit()
CC = /opt/rocm/bin/hipcc
CK_PATH=/dockerx/composable_kernel/
CFLAGS = -O3 -std=c++17 -DCK_AMD_GPU_GFX90A --offload-arch=gfx90a -I"${CK_PATH}/include" -I"${CK_PATH}/library/include" -I"${CK_PATH}/profiler/include"
OBJS = ex.o host_tensor.o device_memory.o
all: $(OBJS)
$(CC) $(CFLAGS) $(OBJS) -o ex
device_memory.o: ../../library/src/utility/device_memory.cpp
$(CC) $(CFLAGS) -c ../../library/src/utility/device_memory.cpp
host_tensor.o: ../../library/src/utility/host_tensor.cpp
$(CC) $(CFLAGS) -c ../../library/src/utility/host_tensor.cpp
import enum
import os.path
import shutil
import functools
import operator
import collections
import re
def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text
class EmitGemmInstance:
    def __init__(self):
        self.gemm_devop_template = """
#pragma once
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
256,
128,
128,
16,
2,
4,
4,
1,
S<8, 2>,
S<8, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<2, 1, 4, 2>,
S<8, 1, 32, 1>,
S<0, 3, 1, 2>,
S<0, 3, 1, 2>,
S<1, 1, 4, 1>,
S<0, 3, 1, 2>,
S<1, 1, 4, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>;
// Host-side reference GEMM used by the verification path in run_gemm below
// (same definition as in the single-instance example).
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
return true;
}
bool run_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
"""
    def emit(self):
        # NOTE: this template is fully hard-coded, so `values` has no ${...} placeholders to fill yet.
        values = {
            'type_a' : 'ck::half_t',
        }
        template = self.gemm_devop_template
        cf = open("xx.cpp", 'w')
        print(SubstituteTemplate(template, values))
        cf.write(SubstituteTemplate(template, values))
        cf.close()
a = EmitGemmInstance()
a.emit()
# TODO: take GEMM parameters as user input and send them to the example template.
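# A possible shape for that flow, left as a comment because it is only a sketch
# (the function name and arguments below are assumptions, not part of this commit):
#
#   def emit_for_user(type_a, type_b, type_c):
#       values = {'type_a': type_a, 'type_b': type_b, 'type_c': type_c}
#       with open("xx.cpp", "w") as cf:
#           cf.write(SubstituteTemplate(EmitGemmInstance().gemm_devop_template, values))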
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmDl : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<K1>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % K1 == 0);
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
DefaultBlock2CTileMap block_2_ctile_map_;
// TODO: unused, but may be useful in future.
index_t M01_;
index_t N01_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmDl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
}
else
{
return false;
}
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
c_grid_desc_m0_m10_m11_n0_n10_n11,
block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
typename M11N11ThreadClusterM110Xs,
typename M11N11ThreadClusterN110Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_k0_m0_m1_k1 =
transform_tensor_descriptor(a_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_grid_desc_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_k0_n0_n1_k1 =
transform_tensor_descriptor(b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_grid_desc_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
N1PerThreadN111>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap& block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_k0_m0_m1_k1,
make_multi_index(0, im0, 0, 0),
a_block_desc_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_block_desc_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(b_grid_desc_k0_n0_n1_k1,
make_multi_index(0, in0, 0, 0),
b_block_desc_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if 1 iteration is left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
};
} // namespace ck
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Backend-agnostic functions for elementwise codegen.
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import jinja2
from aitemplate.backend.backend_spec import BackendSpec
from ...compiler.base import IntImm, IntVar, Operator, Tensor
from ...compiler.tensor_accessor import TensorAccessor
from ...utils import shape_utils
from . import tensor_accessor_codegen
CONSTANT_TEMPLATE = jinja2.Template(
"""
#define FUSED_ELE_THREAD_SIZE 256
const int N_ELEMENTS_PER_THREAD = sizeof({{read_t}}) / sizeof({{data_t}});
const int N_ELEMENTS_PER_READ = sizeof({{read_t}}) / sizeof({{data_t}});
const int N_OPS_PER_THREAD = sizeof({{read_t}}) / sizeof({{op_t}});
"""
)
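# Worked example (hypothetical template values): with read_t = "uint4" (16 bytes),
# data_t = "half" (2 bytes) and op_t = "half2" (4 bytes), each thread reads 16 bytes,
# i.e. N_ELEMENTS_PER_THREAD = N_ELEMENTS_PER_READ = 8 half elements, and the fused op
# body is applied N_OPS_PER_THREAD = 4 times on half2 operands.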
KERNEL_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const {{read_t}}* input{{idx}}")
KERNEL_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("{{read_t}}* output{{idx}}")
KERNEL_TMP_INPUT_TEMPLATE = jinja2.Template("p_tmp_i{{idx}}[i]")
KERNEL_TMP_OUTPUT_TEMPLATE = jinja2.Template("p_tmp_o{{idx}}[i]")
GET_STRIDED_ADDRESS_TEMPLATE = jinja2.Template(
"""
{% if tensor_accessor.is_contiguous %}
{{data_ptr}} = get_strided_address</*data_t*/ {{data_t}},
/*read_t*/ {{read_t}},
/*is_contiguous*/ true>(
{{data_ptr}}, {{data_idx}}, {{tensor_accessor.offset}}, 0, 0);
{% else %}
{{data_ptr}} = get_strided_address</*data_t*/ {{data_t}},
/*read_t*/ {{read_t}},
/*is_contiguous*/ false>(
{{data_ptr}}, {{data_idx}},
{{tensor_accessor.offset}},
{{tensor_accessor.original_total_elements_from_stride_dim}},
{{tensor_accessor.actual_total_elements_from_stride_dim}});
{% endif %}
"""
)
KERNEL_READ_INPUT_TEMPLATE = jinja2.Template(
"""
{{read_t}} *{{input_name}} = const_cast<{{read_t}}*>(input{{input_idx}});
{{get_strided_address}}
{{read_t}} tmp_i{{input_idx}} = *{{input_name}};
const {{op_t}}* p_tmp_i{{input_idx}} = reinterpret_cast<const {{op_t}}*>(&tmp_i{{input_idx}});
"""
)
KERNEL_DEFINE_OUTPUTS_TEMPLATE = jinja2.Template(
"""
{% for idx in indexes %}
{{read_t}} tmp_o{{idx}};
{{op_t}}* p_tmp_o{{idx}} = reinterpret_cast<{{op_t}}*>(&tmp_o{{idx}});
{% endfor %}
"""
)
KERNEL_WRITE_OUTPUT_TEMPLATE = jinja2.Template(
"""
{{get_strided_address}}
*{{output_name}} = tmp_o{{output_idx}};
"""
)
KERNEL_TEMPLATE = jinja2.Template(
"""
__global__ void
{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{index_type}} n_elements) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
const {{index_type}} idx = bid * FUSED_ELE_THREAD_SIZE + tid;
const {{index_type}} idx_elem = idx * N_ELEMENTS_PER_THREAD;
if (idx_elem >= n_elements) {
return;
}
{{read_inputs}}
{{define_outputs}}
#pragma unroll
for (int i = 0; i < N_OPS_PER_THREAD; ++i) {
{{fused_funcs}}
}
{{write_outputs}}
}
"""
)
FUNC_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const void* input{{idx}}")
FUNC_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("void* output{{idx}}")
KERNEL_CALL_INPUT_PARAM_TEMPLATE = jinja2.Template(
"reinterpret_cast<const {{read_t}}*>(input{{idx}})"
)
KERNEL_CALL_OUTPUT_PARAM_TEMPLATE = jinja2.Template(
"reinterpret_cast<{{read_t}}*>(output{{idx}})"
)
FUNC_TEMPLATE = jinja2.Template(
"""
{{head}}
namespace {
{{constant}}
{{custom_libs}}
{{tensor_accessor_lib}}
{{kernel_function}}
} // namespace
void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims_decl}} {{index_type}} n_elements, {{prefix}}Stream_t stream) {
if (n_elements == 0) {
return;
}
int block_size = static_cast<int>(std::ceil(static_cast<double>(n_elements) / N_ELEMENTS_PER_THREAD / FUSED_ELE_THREAD_SIZE));
{{func_name}}<<<block_size, FUSED_ELE_THREAD_SIZE, 0, stream>>>(
{{kernel_call_output_params}},
{{kernel_call_input_params}},
{{dynamic_dims_call}}
n_elements
);
}
"""
)
FUNC_DECL_TEMPLATE = jinja2.Template(
"""
void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{index_type}} n_elements, {{prefix}}Stream_t stream);
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{
{{indent}}{{index_type}} {{func_name}}_n_elements = {{calculate_n}};
{{indent}}invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{func_name}}_n_elements, {{stream}});
{{indent}}}
"""
)
@dataclass
class ElementwiseMetaData:
func_name: str
op_t: str
args: List[Tensor]
outputs: List[Tensor]
@dataclass
class FusedElementwiseMetaData:
# Input / output Tensors and TensorAccessors.
inputs: List[Tensor]
outputs: List[Tensor]
input_accessors: List[TensorAccessor]
output_accessors: List[TensorAccessor]
# Original input / output Tensors before graph transformation.
# Kept here for elementwise -> fused elementwise Tensor mapping.
original_inputs: List[Tensor]
original_outputs: List[Tensor]
read_t: str
op_t: str
data_t: str
input_broadcast_sizes: List[List[IntVar]]
dynamic_dims: List[IntVar]
sub_funcs: List[ElementwiseMetaData]
def gen_function_single_thread(
fused_func_metadata,
input_names,
output_names,
type_converter,
) -> str:
"""Per thread elementwise function codegen."""
tensor_to_expr: Dict[Tensor, str] = {}
body = ""
for tensor, name in zip(fused_func_metadata.original_inputs, input_names):
tensor_to_expr[tensor] = name
tmp_output_idx: int = 0
for func_metadata in fused_func_metadata.sub_funcs:
params: List[str] = []
func_op_t = func_metadata.op_t
input_converter = None
output_converter = None
if func_op_t != fused_func_metadata.op_t:
input_converter = type_converter.get(fused_func_metadata.op_t).get(
func_op_t
)
output_converter = type_converter.get(func_op_t).get(
fused_func_metadata.op_t
)
assert (
input_converter is not None
), "Unsupported convertion from {} to {}".format(
fused_func_metadata.op_t, func_op_t
)
assert (
output_converter is not None
), "Unsupported convertion from {} to {}".format(
func_op_t, fused_func_metadata.op_t
)
for arg in func_metadata.args:
if arg in tensor_to_expr:
param = tensor_to_expr[arg]
params.append(
"{}({})".format(input_converter, param)
if input_converter is not None
else param
)
elif arg.is_a_const_num():
if func_op_t[-1] == "2":
params.append(
"{}({},{})".format(
func_op_t,
str(arg._attrs["value"]),
str(arg._attrs["value"]),
)
)
else:
params.append("{}({})".format(func_op_t, str(arg._attrs["value"])))
else:
raise RuntimeError(
"Cannot generate expression for node {}, ops: {}".format(
arg, func_metadata
)
)
assert (
len(func_metadata.outputs) == 1
), "Operator has more than 1 output! Operator: {}".format(func_metadata)
output = func_metadata.outputs[0]
func_def = "{}({})".format(func_metadata.func_name, ",".join(params))
func_def = (
"{}({})".format(output_converter, func_def)
if output_converter is not None
else func_def
)
if len(output._attrs["dst_ops"]) > 1:
name = "tmp_" + (str)(tmp_output_idx)
tmp_output_idx += 1
body += "{} {} = {};\n".format(fused_func_metadata.op_t, name, func_def)
tensor_to_expr[output] = name
else:
tensor_to_expr[output] = func_def
for tensor, name in zip(fused_func_metadata.original_outputs, output_names):
if tensor not in tensor_to_expr:
raise RuntimeError(
"Cannot generate expression for node {}, outputs: {}".format(
tensor, fused_func_metadata.original_outputs
)
)
expr = tensor_to_expr[tensor]
body += "{} = {};\n".format(name, expr)
return body
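# Example of the generated body (hypothetical backend func name "__hadd2"): fusing a
# single two-input add whose result is written straight to the output produces
#   p_tmp_o0[i] = __hadd2(p_tmp_i0[i],p_tmp_i1[i]);
# An intermediate result that feeds more than one downstream op is first assigned to a
# tmp_<n> variable of type op_t and reused by name in later sub-functions.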
def _get_sub_func_metadata(
ops: List[Operator], data_t: str, op_t: str, backend_spec: BackendSpec
) -> Tuple[List[ElementwiseMetaData], str]:
candidate_op_types = backend_spec.get_candidate_op_types(op_t)
func_enums = []
for op in ops:
func_enum = op._attrs["func"]
func_enums.append(func_enum)
funcs = backend_spec.func_enum_to_func_name.get(func_enum)
if funcs is None:
raise NotImplementedError("Func {} is not supported!".format(func_enum))
for candidate_op_t in candidate_op_types:
func_name = funcs.get(candidate_op_t)
if func_name is not None:
candidate_op_types = backend_spec.get_candidate_op_types(candidate_op_t)
break
if len(candidate_op_types) == 0:
raise RuntimeError(
"Cannot find a common rocm data type! candidate_op_types: {}, op_t: {}.".format(
candidate_op_types, op_t
)
)
if op_t in set(candidate_op_types):
op_t = candidate_op_types[0]
else:
op_t = data_t
candidate_op_types = backend_spec.get_candidate_op_types(op_t)
sub_func_metadata = []
for op in ops:
func_enum = op._attrs["func"]
funcs = backend_spec.func_enum_to_func_name.get(func_enum)
func_name = None
func_op_t = None
for candidate_op_t in candidate_op_types:
func_name = funcs.get(candidate_op_t)
if func_name is not None:
func_op_t = candidate_op_t
break
if func_name is None:
raise NotImplementedError(
"Unsupported func {} and op type {}!".format(func_enum, op_t)
)
sub_func_metadata.append(
ElementwiseMetaData(
func_name, func_op_t, op._attrs["args"], op._attrs["outputs"]
)
)
return (sub_func_metadata, op_t)
def _get_types_and_sizes(
inputs: List[Tensor],
input_accessors: List[TensorAccessor],
output_accessors: List[TensorAccessor],
backend_spec: BackendSpec,
) -> Tuple[int, List[List[IntVar]], str]:
"""
Returns Tuple(alignment, input_broadcast_sizes, dtype)
"""
# Handle input broadcast.
output_shape = output_accessors[0].original_shapes
dtype = inputs[0]._attrs["dtype"]
input_broadcast_sizes = []
min_num_elements = None
for input_accessor in input_accessors:
input_shape = input_accessor.original_shapes
broadcastable, _ = shape_utils.get_broadcast_max_shape(
output_shape, input_shape
)
if not broadcastable:
raise RuntimeError(
"Input shape {} is not compatible with output shape {}!".format(
input_shape, output_shape
)
)
num_rightmost_non_broadcast_elements = len(input_shape)
extended_input_shape = list(input_shape)
if input_shape == output_shape:
input_broadcast_sizes.append(None)
else:
extended_input_shape = [IntImm(1)] * len(output_shape)
extended_input_shape[len(output_shape) - len(input_shape) :] = input_shape
input_broadcast_sizes.append(extended_input_shape)
for i in reversed(range(len(extended_input_shape))):
if extended_input_shape[i] != output_shape[i]:
num_rightmost_non_broadcast_elements -= i + 1
break
num_elements_for_alignments = shape_utils.get_num_rightmost_static_elements(
extended_input_shape, num_rightmost_non_broadcast_elements
)
if not min_num_elements:
min_num_elements = num_elements_for_alignments
else:
min_num_elements = min(min_num_elements, num_elements_for_alignments)
alignment = tensor_accessor_codegen.find_max_alignment(
min_num_elements, output_accessors
)
# Note that we use the same alignment for accessing inputs and outputs, although
# they may have different alignment requirements. We may lose perf a little bit,
# but reduce the complexity of our jinja template. We can do some perf
# experiments later to determine if we want to chase more perf gains.
alignment = tensor_accessor_codegen.find_max_alignment(alignment, input_accessors)
return alignment, input_broadcast_sizes, dtype
def _get_dynamic_dims(output_accessors: List[TensorAccessor]) -> List[IntVar]:
res = {}
for output_accessor in output_accessors:
for dim in output_accessor.original_shapes:
if not isinstance(dim, IntImm):
res[dim._attrs["name"]] = dim
return res.values()
def _parse_func_metadata(
ops: List[Operator],
inputs: List[Tensor],
outputs: List[Tensor],
input_accessors: List[TensorAccessor],
output_accessors: List[TensorAccessor],
original_inputs: List[Tensor],
original_outputs: List[Tensor],
backend_spec: BackendSpec,
) -> FusedElementwiseMetaData:
alignment, input_broadcast_sizes, dtype = _get_types_and_sizes(
inputs, input_accessors, output_accessors, backend_spec
)
read_type = backend_spec.get_backend_type(
alignment, dtype, backend_spec.read_num_elements_to_backend_type
)
op_type = backend_spec.get_backend_type(
alignment, dtype, backend_spec.op_num_elements_to_backend_type
)
data_type = backend_spec.dtype_to_backend_type(dtype)
sub_func_metadata, op_type = _get_sub_func_metadata(
ops, data_type, op_type, backend_spec
)
dynamic_dims = _get_dynamic_dims(output_accessors)
return FusedElementwiseMetaData(
inputs,
outputs,
input_accessors,
output_accessors,
original_inputs,
original_outputs,
read_type,
op_type,
data_type,
input_broadcast_sizes,
dynamic_dims,
sub_func_metadata,
)
def _gen_int_var_product_str(
int_vars: List[IntVar],
) -> str:
res = []
for int_var in int_vars:
if isinstance(int_var, IntImm):
res.append(str(int_var._attrs["values"][0]))
elif isinstance(int_var, IntVar):
res.append(int_var._attrs["name"])
else:
raise RuntimeError(
"A dim must be an IntVar! Current type: {}".format(type(int_var))
)
return " * ".join(res)
def _gen_input_broadcast_calculator_str(
input_shape: List[IntVar],
output_shape: List[IntVar],
) -> str:
output_num_elements = []
output_strides = []
input_strides = []
start_idx = 0
for i, (input_dim, output_dim) in enumerate(zip(input_shape, output_shape)):
if input_dim != output_dim:
assert input_dim == IntImm(
1
), "Unexpected shapes! Input: {}, output: {}".format(
input_shape, output_shape
)
input_strides.append(input_shape[i:])
output_strides.append(output_shape[i:])
output_num_elements.append(output_shape[start_idx:])
start_idx = i + 1
if start_idx < len(output_shape):
input_strides.append([IntImm(1)])
output_strides.append([IntImm(1)])
output_num_elements.append(output_shape[start_idx:])
res = []
for (output_num_element, output_stride, input_stride) in zip(
output_num_elements, output_strides, input_strides
):
res.append(
"{} % ({}) / ({}) * ({})".format(
"idx * N_ELEMENTS_PER_THREAD",
_gen_int_var_product_str(output_num_element),
_gen_int_var_product_str(output_stride),
_gen_int_var_product_str(input_stride),
)
)
return " + ".join(res)
def _gen_input_broadcast_size_str(
input_broadcast_sizes: List[List[IntVar]],
output_shape: List[IntVar],
) -> List[str]:
res = []
for input_broadcast_size in input_broadcast_sizes:
if input_broadcast_size is None:
res.append("")
else:
res.append(
_gen_input_broadcast_calculator_str(input_broadcast_size, output_shape)
)
return res
def _gen_dynamic_dim_str(
index_type: str, dynamic_dims: List[IntVar], has_type: bool
) -> str:
type_str = index_type + " " if has_type else ""
res = ", ".join([type_str + dim._attrs["name"] for dim in dynamic_dims])
if res:
res += ", "
return res
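# Example (hypothetical dynamic dims named "batch" and "seq"): with index_type "int64_t"
# and has_type=True this returns "int64_t batch, int64_t seq, "; with has_type=False it
# returns "batch, seq, ". The trailing ", " lets callers splice the result directly into
# a parameter list.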
def _gen_read_inputs_str(
fused_elementwise_metadata: FusedElementwiseMetaData, broadcast_sizes: List[str]
):
read_inputs = []
for input_idx, (input_accessor, broadcast_size) in enumerate(
zip(fused_elementwise_metadata.input_accessors, broadcast_sizes)
):
input_name = f"input_tmp{input_idx}"
data_idx = (
"idx"
if not broadcast_size
else f"({broadcast_size}) / N_ELEMENTS_PER_THREAD"
)
get_strided_addr_str = GET_STRIDED_ADDRESS_TEMPLATE.render(
tensor_accessor=input_accessor,
data_ptr=input_name,
data_t=fused_elementwise_metadata.data_t,
read_t=fused_elementwise_metadata.read_t,
data_idx=data_idx,
)
read_input = KERNEL_READ_INPUT_TEMPLATE.render(
get_strided_address=get_strided_addr_str,
input_name=input_name,
input_idx=input_idx,
read_t=fused_elementwise_metadata.read_t,
op_t=fused_elementwise_metadata.op_t,
)
read_inputs.append(read_input)
read_inputs_str = "\n".join(read_inputs)
return read_inputs_str
def _gen_write_outputs_str(fused_elementwise_metadata: FusedElementwiseMetaData):
write_outputs = []
for output_idx, output_accessor in enumerate(
fused_elementwise_metadata.output_accessors
):
output_name = f"output{output_idx}"
get_strided_addr_str = GET_STRIDED_ADDRESS_TEMPLATE.render(
tensor_accessor=output_accessor,
data_ptr=output_name,
data_t=fused_elementwise_metadata.data_t,
read_t=fused_elementwise_metadata.read_t,
data_idx="idx",
)
write_out = KERNEL_WRITE_OUTPUT_TEMPLATE.render(
get_strided_address=get_strided_addr_str,
output_name=output_name,
output_idx=output_idx,
)
write_outputs.append(write_out)
write_outputs_str = "\n".join(write_outputs)
return write_outputs_str
def _gen_kernel_function(
func_attrs: Dict[str, Any],
index_type: str,
fused_elementwise_metadata: FusedElementwiseMetaData,
backend_datatype_convertors: Dict[str, Dict[str, str]],
) -> str:
output_params_decl = ",".join(
[
KERNEL_DECL_OUTPUT_PARAM_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t, idx=i
)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
)
input_params_decl = ",".join(
[
KERNEL_DECL_INPUT_PARAM_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t, idx=i
)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
)
broadcast_sizes = _gen_input_broadcast_size_str(
fused_elementwise_metadata.input_broadcast_sizes,
fused_elementwise_metadata.output_accessors[0].original_shapes,
)
read_inputs_str = _gen_read_inputs_str(fused_elementwise_metadata, broadcast_sizes)
define_outputs = KERNEL_DEFINE_OUTPUTS_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t,
op_t=fused_elementwise_metadata.op_t,
indexes=list(range(len(fused_elementwise_metadata.outputs))),
)
write_outputs_str = _gen_write_outputs_str(fused_elementwise_metadata)
input_names = [
KERNEL_TMP_INPUT_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
output_names = [
KERNEL_TMP_OUTPUT_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
fused_funcs = gen_function_single_thread(
fused_elementwise_metadata,
input_names,
output_names,
backend_datatype_convertors,
)
kernel_func = KERNEL_TEMPLATE.render(
func_name=func_attrs["name"],
index_type=index_type,
output_params=output_params_decl,
input_params=input_params_decl,
dynamic_dims=_gen_dynamic_dim_str(
index_type, fused_elementwise_metadata.dynamic_dims, has_type=True
),
read_inputs=read_inputs_str,
define_outputs=define_outputs,
write_outputs=write_outputs_str,
fused_funcs=fused_funcs,
)
return kernel_func
def fused_elementwise_gen_function(
func_attrs: Dict[str, Any],
custom_libs: str,
head_template: str,
backend_spec: BackendSpec,
) -> str:
"""Generates fused_elementwise function definition."""
ops = func_attrs["elementwise_ops"]
inputs = func_attrs["inputs"]
outputs = func_attrs["outputs"]
input_accessors = func_attrs["input_accessors"]
output_accessors = func_attrs["output_accessors"]
original_inputs = func_attrs["original_inputs"]
original_outputs = func_attrs["original_outputs"]
fused_elementwise_metadata = _parse_func_metadata(
ops,
inputs,
outputs,
input_accessors,
output_accessors,
original_inputs,
original_outputs,
backend_spec,
)
# Dump data types into func_attr for testing purpose.
func_attrs["read_t"] = fused_elementwise_metadata.read_t
func_attrs["op_t"] = fused_elementwise_metadata.op_t
func_attrs["data_t"] = fused_elementwise_metadata.data_t
tensor_accessor_lib = tensor_accessor_codegen.get_libs()
tensor_accessor_lib_str = "\n\n" + tensor_accessor_lib + "\n\n"
kernel_function = _gen_kernel_function(
func_attrs,
backend_spec.index_type,
fused_elementwise_metadata,
backend_spec.backend_datatype_convertors,
)
output_params_decl = ",".join(
[
FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
)
input_params_decl = ",".join(
[
FUNC_DECL_INPUT_PARAM_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
)
kernel_call_output_params = ",".join(
[
KERNEL_CALL_OUTPUT_PARAM_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t, idx=i
)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
)
kernel_call_input_params = ",".join(
[
KERNEL_CALL_INPUT_PARAM_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t, idx=i
)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
)
constant = CONSTANT_TEMPLATE.render(
read_t=fused_elementwise_metadata.read_t,
op_t=fused_elementwise_metadata.op_t,
data_t=fused_elementwise_metadata.data_t,
)
function = FUNC_TEMPLATE.render(
prefix=backend_spec.prefix,
index_type=backend_spec.index_type,
head=backend_spec.header_src_template.render(extra_header=head_template),
constant=constant,
custom_libs=custom_libs,
tensor_accessor_lib=tensor_accessor_lib_str,
kernel_function=kernel_function,
func_name=func_attrs["name"],
output_params=output_params_decl,
input_params=input_params_decl,
dynamic_dims_decl=_gen_dynamic_dim_str(
backend_spec.index_type,
fused_elementwise_metadata.dynamic_dims,
has_type=True,
),
dynamic_dims_call=_gen_dynamic_dim_str(
backend_spec.index_type,
fused_elementwise_metadata.dynamic_dims,
has_type=False,
),
kernel_call_output_params=kernel_call_output_params,
kernel_call_input_params=kernel_call_input_params,
)
return function
def fused_elementwise_gen_function_decl(
func_attrs,
backend_spec: BackendSpec,
):
"""Generates fused_elementwise function declaration."""
func_name = func_attrs["name"]
ops = func_attrs["elementwise_ops"]
inputs = func_attrs["inputs"]
outputs = func_attrs["outputs"]
input_accessors = func_attrs["input_accessors"]
output_accessors = func_attrs["output_accessors"]
original_inputs = func_attrs["original_inputs"]
original_outputs = func_attrs["original_outputs"]
fused_elementwise_metadata = _parse_func_metadata(
ops,
inputs,
outputs,
input_accessors,
output_accessors,
original_inputs,
original_outputs,
backend_spec,
)
output_params_decl = ",".join(
[
FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.outputs)
]
)
input_params_decl = ",".join(
[
FUNC_DECL_INPUT_PARAM_TEMPLATE.render(idx=i)
for i, _ in enumerate(fused_elementwise_metadata.inputs)
]
)
function_decl = FUNC_DECL_TEMPLATE.render(
prefix=backend_spec.prefix,
index_type=backend_spec.index_type,
func_name=func_name,
output_params=output_params_decl,
input_params=input_params_decl,
dynamic_dims=_gen_dynamic_dim_str(
backend_spec.index_type,
fused_elementwise_metadata.dynamic_dims,
has_type=True,
),
)
return function_decl
def fused_elementwise_gen_function_call(
func_attrs,
indent: str,
backend_spec: BackendSpec,
):
"""Generates fused_elementwise function call."""
ops = func_attrs["elementwise_ops"]
inputs = func_attrs["inputs"]
outputs = func_attrs["outputs"]
input_accessors = func_attrs["input_accessors"]
output_accessors = func_attrs["output_accessors"]
original_inputs = func_attrs["original_inputs"]
original_outputs = func_attrs["original_outputs"]
fused_elementwise_metadata = _parse_func_metadata(
ops,
inputs,
outputs,
input_accessors,
output_accessors,
original_inputs,
original_outputs,
backend_spec,
)
output_params = ",".join([output._attrs["name"] for output in outputs])
input_params = ",".join([input._attrs["name"] for input in inputs])
num_elements_calculator = _gen_int_var_product_str(
output_accessors[0].original_shapes
)
return FUNC_CALL_TEMPLATE.render(
stream=backend_spec.stream,
func_name=func_attrs["name"],
index_type=backend_spec.index_type,
calculate_n=num_elements_calculator,
output_params=output_params,
input_params=input_params,
dynamic_dims=_gen_dynamic_dim_str(
backend_spec.index_type,
fused_elementwise_metadata.dynamic_dims,
has_type=False,
),
indent=indent,
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Common template for gemm
"""
import os
import re
from collections import OrderedDict
from hashlib import sha1
import jinja2
from ...common import gemm_common
from ...target import Target
# pylint: disable=C0103,C0415,W0611,C0301
EXTRA_SHAPE_TEMPLATE = jinja2.Template(
"""
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
"""
)
INSTANCE_TEMPLATE = jinja2.Template(
"""
{{config}}
using {{name}} = {{ config_name }};
"""
)
EXEC_TEMPLATE = jinja2.Template(
"""
{{indent}}auto op = {{instance}}{};
{{indent}}auto invoker = op.MakeInvoker();
{{indent}}auto argument = op.MakeArgument(
{{problem_args}}
{{indent}});
{{indent}}if(!op.IsSupportedArgument(argument)) {
{{indent}} throw std::runtime_error(
{{indent}} "wrong! device_gemm with the specified compilation parameters does "
{{indent}} "not support this Gemm problem");
{{indent}}}
{% if is_profiler %}
{{indent}}auto workspace_size = op.GetWorkSpaceSize(&argument);
{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size;
{% endif %}
{{indent}}invoker.Run(argument, StreamConfig{stream, false});
{{indent}}return;
"""
)
EXTRA_HEADER_TEMPLATE = jinja2.Template(
"""
{% if gemm_flag == "" %}
#include "include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
{% elif gemm_flag == "permute_m2n3" %}
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% elif "bias" in gemm_flag or has_d0 %}
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
{% if gemm_flag == "bias_permute" %}
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
{% elif gemm_flag in ["bias_permute_m2n3", "bias_permute_m3n2"] %}
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% endif %}
{% endif %}
"""
)
SRC_TEMPLATE = jinja2.Template(
"""
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
{{extra_header}}
{{extra_code}}
{{instances}}
void {{function_name}}(
void * in_ptr,
void * weight_ptr,
void * out_ptr,
{% if "bias" in gemm_flag %}
void * bias_ptr,
{% endif %}
{% if has_d0 %}
void * d0_ptr,
{% endif %}
{% if has_d1 %}
void * d1_ptr,
{% endif %}
{% for idx in range(ndims) %}
int64_t* a_dim{{idx}},
{% endfor %}
{% for idx in range(ndims) %}
int64_t* b_dim{{idx}},
{% endfor %}
{% for idx in range(ndims) %}
int64_t* c_dim{{idx}},
{% endfor %}
{% for idx in range(pdims) %}
const int p_dim{{idx}},
{% endfor %}
hipStream_t stream
) {
{{shape_func}}
{{extra_shape}}
{{input_addr_calculator}}
{{output_addr_calculator}}
{{exec_paths}}
throw std::runtime_error(
"Unsupported workload for this gemm specialization."
);
}
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{in_ptr}},
{{indent}} {{weight_ptr}},
{{indent}} {{out_ptr}},
{% if "bias" in gemm_flag %}
{{indent}} {{bias_ptr}},
{% endif %}
{% if d0_ptr != "" %}
{{indent}} {{d0_ptr}},
{% endif %}
{% if d1_ptr != "" %}
{{indent}} {{d1_ptr}},
{% endif %}
{% for dim in adims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in bdims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in cdims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in pdims %}
{{indent}} {{dim}},
{% endfor %}
{{indent}} stream
{{indent}});
"""
)
PROBLEM_ARGS_TEMPLATE = jinja2.Template(
"""
{{indent}} static_cast<ck::half_t *>(in_ptr),
{{indent}} static_cast<ck::half_t *>(weight_ptr),
{% if gemm_flag == "bias_permute" %}
{{indent}} static_cast<ck::half_t *>(bias_ptr),
{% elif gemm_flag == "bias_permute_m2n3" %}
{{indent}} std::array<const void*, 1>{static_cast<ck::half_t *>(bias_ptr)},
{% elif gemm_flag == "permute_m2n3" %}
{{indent}} {},
{% else %}
{% if "bias" in gemm_flag and not has_d0 %}
{{indent}} std::array<const void*, 1>{static_cast<ck::half_t *>(bias_ptr)},
{% elif has_d0 and not has_d1 %}
{{indent}} std::array<const void*, 2>{static_cast<ck::half_t *>(bias_ptr),
static_cast<ck::half_t *>(d0_ptr)},
{% elif has_d1 %}
{{indent}} std::array<const void*, 3>{static_cast<ck::half_t *>(bias_ptr),
static_cast<ck::half_t *>(d0_ptr),
static_cast<ck::half_t *>(d1_ptr)},
{% endif %}
{% endif %}
{{indent}} static_cast<ck::half_t *>(out_ptr),
{% if gemm_flag not in ["permute_m2n3", "bias_permute_m2n3", "bias_permute_m3n2"] %}
{{indent}} M,
{{indent}} N,
{{indent}} K,
{{indent}} stride_a,
{{indent}} stride_b,
{% endif %}
{% if gemm_flag == "bias_permute" %}
{{indent}} {M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1},
{{indent}} {M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1},
{% elif gemm_flag in ["permute_m2n3", "bias_permute_m2n3", "bias_permute_m3n2"] %}
{{indent}} a_ms_ks_lengths,
{{indent}} a_ms_ks_strides,
{{indent}} b_ns_ks_lengths,
{{indent}} b_ns_ks_strides,
{% if gemm_flag == "permute_m2n3" %}
{{indent}} {},
{{indent}} {},
{% else %}
{{indent}} std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
{{indent}} std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
{% endif %}
{{indent}} e_ms_ns_lengths,
{{indent}} e_ms_ns_strides,
{% else %}
{% if "bias" in gemm_flag and not has_d0 %}
{{indent}} std::array<ck::index_t, 1>{0},
{% elif has_d0 and not has_d1 %}
{{indent}} std::array<ck::index_t, 2>{0, static_cast<int>(stride_c)},
{% elif has_d1 %}
{{indent}} std::array<ck::index_t, 3>{0, static_cast<int>(stride_c), static_cast<int>(stride_c)},
{% endif %}
{{indent}} stride_c,
{% endif %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{},
{{indent}} ck::tensor_operation::element_wise::PassThrough{},
{% if gemm_flag == "" %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{}
{% elif gemm_flag == "permute_m2n3" %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{}
{% elif gemm_flag == "bias" or "bias_permute" in gemm_flag %}
{{indent}} ck::tensor_operation::element_wise::Add{}
{% elif gemm_flag == "bias_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddRelu{}
{% elif gemm_flag == "bias_fast_gelu" %}
{{indent}} ck::tensor_operation::element_wise::AddFastGelu{}
{% elif gemm_flag == "bias_swish" %}
{{indent}} ck::tensor_operation::element_wise::AddHardswish{}
{% elif gemm_flag == "bias_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddTanh{}
{% elif gemm_flag == "bias_sigmoid" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoid{}
{% elif gemm_flag == "bias_add" %}
{{indent}} ck::tensor_operation::element_wise::AddAdd{}
{% elif gemm_flag == "bias_mul" %}
{{indent}} ck::tensor_operation::element_wise::AddMul{}
{% elif gemm_flag == "bias_mul_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddMulTanh{}
{% elif gemm_flag == "bias_add_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddRelu{}
{% elif gemm_flag == "bias_add_fast_gelu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddFastGelu{}
{% elif gemm_flag == "bias_sigmoid_mul" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoidMul{}
{% elif gemm_flag == "bias_sigmoid_mul_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoidMulTanh{}
{% elif gemm_flag == "bias_mul_add" %}
{{indent}} ck::tensor_operation::element_wise::AddMulAdd{}
{% elif gemm_flag == "bias_add_add" %}
{{indent}} ck::tensor_operation::element_wise::AddAddAdd{}
{% elif gemm_flag == "bias_add_add_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddAddRelu{}
{% endif %}
"""
)
TENSOR_DECL_TEMPLATE = jinja2.Template(
"""
int64_t a_ptr_sz = M*K;
int64_t b_ptr_sz = N*K;
int64_t c_ptr_sz = M*N;
int64_t ptr_max_sz = std::max({a_ptr_sz, b_ptr_sz, c_ptr_sz});
// TODO: special pool size for 8M L2 cache
// need to tune it for other devices
int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 23) / ptr_max_sz)));
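// worked example: for M = N = K = 1024, ptr_max_sz = 1048576 elements and
// (1 << 23) / ptr_max_sz = 8, so 8 rotating copies of each tensor are allocated so
// that back-to-back profiler runs do not keep reusing the same cached data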
memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // x: index 0
memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // w: index 1
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // y: index 2
{% if "bias" in gemm_flag %}
memory_pool->AllocateHalfTensor(N, mem_pool_sz); // b: index 3
{% endif %}
{% if has_d0 %}
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d0 ptr: index 4
{% endif %}
{% if has_d1 %}
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d1 ptr: index 5
{% endif %}
"""
)
STRUCTS_DEF_TEMPLATE = jinja2.Template(
"""
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
ptrs.reserve(512);
// create the rocRAND generator used by AllocateGaussianTensor; without this,
// rocrand_set_seed below would operate on an uninitialized handle
rocrand_create_generator(&generator, ROCRAND_RNG_PSEUDO_DEFAULT);
}
~ProfilerMemoryPool() {
for(int i = 0; i < ptrs.size(); i++){
hipFree(ptrs[i]);
}
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
size_t length = size * sizeof(DType);
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
"""
)
PROFILER_TEMPLATE = jinja2.Template(
"""
size_t GLOBAL_WORKSPACE_SIZE = 0;
{{op_func}}
{{structs_def}}
int main(int argc, char** argv) {
if (argc < 4) {
throw std::runtime_error("wrong params");
}
{{args_parse}}
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
{{tensor_decl}}
// TODO: random init
// warmup
for(int i = 0; i < 3; ++i) {
{{func_call}}
}
// run
auto timer = new KernelTimerImpl();
timer->Start();
for(int i = 0; i < 5; ++i) {
{{func_call}}
}
timer->End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer->GetElapsedTime() << std::endl;
delete(timer);
}
"""
)
FUNC_DECL_TEMPLATE = jinja2.Template(
"""
void {{func_name}}(
void *,
void *,
void *,
{% if "bias" in gemm_flag %}
void *,
{% endif %}
{% if has_d0 %}
void *,
{% endif %}
{% if has_d1 %}
void *,
{% endif %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(pdims) %}
const int,
{% endfor %}
hipStream_t
);
"""
)
def has_d0(func_attrs):
return func_attrs.get("num_sources", 0) >= 1
def has_d1(func_attrs):
return func_attrs.get("num_sources", 0) >= 2
def emit_instance(op):
"""Emit instance."""
import ck_lib # noqa: F401
op_def = op.emit()
return op_def
def extract_config(op_kind, extra_kind, f_proc_op):
"""Extract (operation name, operation instance) pair
from all operation candidates.
Parameters
----------
op_kind : ck_lib.library.OperationKind
Operation kind.
extra_kind : ck_lib.library.[AnyKind]
Used as an extra flag to distinguish kernels.
E.g. bias_add_relu vs. add_relu_bias
f_proc_op : function
Used to filter operations.
Returns
-------
Dict
Extracted (operation name, operation instance) pair.
"""
gemm_ops = OrderedDict()
extract_ops = list(Target.current()._operators[op_kind][extra_kind].items())
for key, value in extract_ops:
op = value[0]
ret = f_proc_op(op)
if len(ret) > 0:
for op_inst in ret:
gemm_ops[key] = op_inst
return gemm_ops
def extract_config_name(config):
"""Exract name from the statement, e.g. 'model' for 'using model = xxx'.
Parameters
----------
config : str
Configuration as a string in the format of 'using model = xxx'.
Returns
-------
str
Extracted name from the statement.
Raises
------
RuntimeError
Invalid config.
"""
pattern = re.compile(r"\s*using\s(.*?)\s=")
decl = config.split("\n")[1]
match = pattern.match(decl)
if match is None:
raise RuntimeError("Invalid config: \n" + config)
return match.groups()[0]
def gen_profiler(
func_attrs,
workdir,
dim_info_dict,
args_parse,
gemm_flag,
extra_code="",
ndims=2,
extra_shape_template=EXTRA_SHAPE_TEMPLATE,
problem_args_template=PROBLEM_ARGS_TEMPLATE,
extra_header_template=EXTRA_HEADER_TEMPLATE,
tensor_decl_template=TENSOR_DECL_TEMPLATE,
):
"""Generates standalone executables for profiler.
Parameters
----------
func_attrs : Dict
Operation attributes.
workdir : str
Directory to store the generated outputs.
dim_info_dict: Dict[str, DimInfo]
Generated from gemm._extract_dims().
Used to store the mapping from dim_names to input / output tensor dims.
args_parse: str
Profiler input argument parser.
gemm_flag : str
Flag telling which kernel variant should be generated. Options are '', 'bias', 'bias_relu', 'bias_sigmoid', 'bias_add_relu'.
extra_code : str
Extra code for self-defined operators.
ndims : int
Number of dims for each parameter, 2 for gemm, 3 for bmm
extra_shape_template: jinja2.Template
Shape evaluation template.
problem_args_template: jinja2.Template
Problem args template for profiler.
extra_header_template: jinja2.Template
Extra header template as we have different headers for gemm and bmm.
tensor_decl_template: jinja2.Template
Tensor declaration template.
"""
op_type = func_attrs["op"]
op_instance = func_attrs["op_instance"]
# shape function
op_func_shape = gemm_common.gen_shape_eval_code(
indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True
)
adims = ["&a_dim" + str(i) for i in range(ndims)]
bdims = ["&b_dim" + str(i) for i in range(ndims)]
cdims = ["&c_dim" + str(i) for i in range(ndims)]
pdims = []
if func_attrs.get("shape") is not None:
pdims = ["p_dim" + str(i) for i in range(len(func_attrs["shape"]))]
extra_shape_func = extra_shape_template.render(indent=" ")
file_pairs = []
has_d0_flag = has_d0(func_attrs)
has_d1_flag = has_d1(func_attrs)
for op_name, op in op_instance.items():
config = emit_instance(op)
config_name = extract_config_name(config)
instance = INSTANCE_TEMPLATE.render(
name="DeviceGemmInstance", config_name=config_name, config=config
)
problem_args = problem_args_template.render(
indent=" ",
gemm_flag=gemm_flag,
has_d0=has_d0_flag,
has_d1=has_d1_flag,
)
exec_program = EXEC_TEMPLATE.render(
indent=" ",
instance="DeviceGemmInstance",
problem_args=problem_args,
is_profiler=True,
)
extra_header = extra_header_template.render(
gemm_flag=gemm_flag, has_d0=has_d0_flag
)
op_func = SRC_TEMPLATE.render(
instances=instance,
function_name="gemm",
ndims=ndims,
pdims=len(pdims),
has_d0=has_d0_flag,
has_d1=has_d1_flag,
shape_func=op_func_shape,
extra_shape=extra_shape_func,
exec_paths=exec_program,
extra_code=extra_code,
gemm_flag=gemm_flag,
extra_header=extra_header,
)
structs_def = STRUCTS_DEF_TEMPLATE.render()
tensor_decl = tensor_decl_template.render(
gemm_flag=gemm_flag, has_d0=has_d0_flag, has_d1=has_d1_flag
)
func_call = FUNC_CALL_TEMPLATE.render(
func_name="gemm",
in_ptr="(void *) memory_pool->RequestHalfTensorByIdx(0)",
weight_ptr="(void *) memory_pool->RequestHalfTensorByIdx(1)",
out_ptr="(void *) memory_pool->RequestHalfTensorByIdx(2)",
bias_ptr="(void *) memory_pool->RequestHalfTensorByIdx(3)",
d0_ptr="(void *) memory_pool->RequestHalfTensorByIdx(4)"
if has_d0_flag
else "",
d1_ptr="(void *) memory_pool->RequestHalfTensorByIdx(5)"
if has_d1_flag
else "",
adims=adims,
bdims=bdims,
cdims=cdims,
pdims=pdims,
gemm_flag=gemm_flag,
)
code = PROFILER_TEMPLATE.render(
structs_def=structs_def,
op_func=op_func,
args_parse=args_parse,
tensor_decl=tensor_decl,
func_call=func_call,
)
prefix = os.path.join(workdir, "profiler", op_type)
if not os.path.exists(prefix):
os.makedirs(prefix)
src_path = os.path.join(prefix, op_name + ".cpp")
obj_path = os.path.join(prefix, op_name)
if os.path.exists(obj_path):
continue
with open(src_path, "w") as fo:
fo.write(code)
file_pairs.append((src_path, obj_path))
return file_pairs
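# For example (hypothetical workdir "./tmp" with op_type "gemm_rrr"), each op instance is
# written to ./tmp/profiler/gemm_rrr/<op_name>.cpp, and the returned file_pairs list of
# (src_path, obj_path) tuples is handed to the build step that compiles one standalone
# profiler executable per instance.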
def gen_function(
func_attrs,
exec_cond_template,
dim_info_dict,
gemm_flag,
extra_code="",
ndims=2,
extra_shape_template=EXTRA_SHAPE_TEMPLATE,
problem_args_template=PROBLEM_ARGS_TEMPLATE,
extra_header_template=EXTRA_HEADER_TEMPLATE,
input_addr_calculator="",
output_addr_calculator="",
):
"""Generate function body.
Parameters
----------
func_attrs : Dict
Operation attributes.
exec_cond_template : jinja2.Template
Generates if statement to execute kernel.
dim_info_dict: Dict[str, DimInfo]
Generated from gemm._extract_dims().
Used to store the mapping from dim_names to input / output tensor dims.
gemm_flag : str
Flag telling which kernel variant should be generated. Options are '', 'bias', 'bias_relu', 'bias_add_relu'.
extra_code : str
Extra code for self-defined operators.
ndims : int
Number of dims for each parameter, 2 for gemm, 3 for bmm.
extra_shape_template: jinja2.Template
Shape evaluation template.
extra_header_template: jinja2.Template
Extra header template as we have different headers for gemm and bmm.
input_addr_calculator : str
Used to adjust input address based on input tensor accessors if accessors exist
output_addr_calculator : str
Used to adjust output address based on output tensor accessors if accessors exist
Returns
-------
str
The rendered template of generated function body.
"""
func_name = func_attrs["name"]
exec_path = func_attrs["exec_path"]
op_instance = func_attrs["op_instance"]
inst_def_flag = set()
instances = {}
instance_decl = ""
has_d0_flag = has_d0(func_attrs)
has_d1_flag = has_d1(func_attrs)
for key, value in exec_path.items():
fname = "f" + sha1(key.encode()).hexdigest()
algo = value.algo
if algo not in inst_def_flag:
config = emit_instance(op_instance[algo])
inst_def_flag.add(algo)
else:
config = ""
inst = INSTANCE_TEMPLATE.render(
config=config, name=fname, config_name=extract_config_name(config)
)
instances[key] = inst
instance_decl += inst
extra_shape_func = extra_shape_template.render(indent=" ")
shape_eval_func = gemm_common.gen_shape_eval_code(
indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True
)
exec_paths = ""
for key, _ in instances.items():
fname = "f" + sha1(key.encode()).hexdigest()
problem_args = problem_args_template.render(
indent=" ",
gemm_flag=gemm_flag,
has_d0=has_d0_flag,
has_d1=has_d1_flag,
)
program = EXEC_TEMPLATE.render(
indent=" ",
instance=fname,
problem_args=problem_args,
is_profiler=False,
)
exec_inst = exec_cond_template.render(indent=" ", cond=key, program=program)
exec_paths += exec_inst
extra_header = extra_header_template.render(
gemm_flag=gemm_flag, has_d0=has_d0(func_attrs)
)
pdims = len(func_attrs["shape"]) if func_attrs.get("shape") is not None else 0
return SRC_TEMPLATE.render(
instances=instance_decl,
function_name=func_name,
shape_func=shape_eval_func,
extra_shape=extra_shape_func,
input_addr_calculator=input_addr_calculator,
output_addr_calculator=output_addr_calculator,
exec_paths=exec_paths,
extra_code=extra_code,
extra_header=extra_header,
gemm_flag=gemm_flag,
ndims=ndims,
pdims=pdims,
has_d0=has_d0_flag,
has_d1=has_d1_flag,
)
def gen_function_decl(func_name, gemm_flag, ndims=2, pdims=0, has_d0="", has_d1=""):
"""Generates function declarations.
Parameters
----------
func_name : str
Name of the generated function.
gemm_flag : str
Flag telling which kernel variant should be generated. Options are '', 'bias', 'bias_relu'.
ndims : int
Number of dims for each parameter, 2 for gemm, 3 for bmm.
Returns
-------
str
The rendered template of function declaration.
"""
return FUNC_DECL_TEMPLATE.render(
func_name=func_name,
gemm_flag=gemm_flag,
ndims=ndims,
pdims=pdims,
has_d0=has_d0,
has_d1=has_d1,
)
def gen_function_call(func_attrs, indent=" ", gemm_flag=""):
"""Generates function call.
Parameters
----------
func_attrs : Dict
Stores the operation attributes.
indent : str, optional
Indent for codegen, target dependent e.g. C++, python, etc., by default " ".
gemm_flag : str
Flag telling which kernel variant should be generated. Options are '', 'bias', 'bias_relu'.
Returns
-------
str
The rendered template of generated function call.
"""
a = func_attrs["inputs"][0]
b = func_attrs["inputs"][1]
c = func_attrs["outputs"][0]
bias_ptr = ""
if "bias" in gemm_flag:
bias = func_attrs["inputs"][2]
bias_ptr = bias._attrs["name"]
d0_ptr = ""
if has_d0(func_attrs):
d0 = func_attrs["inputs"][3]
d0_ptr = d0._attrs["name"]
d1_ptr = ""
if has_d1(func_attrs):
d1 = func_attrs["inputs"][4]
d1_ptr = d1._attrs["name"]
adims = [
"&" + dim._attrs["name"]
for dim in func_attrs["input_accessors"][0].original_shapes
]
bdims = [
"&" + dim._attrs["name"]
for dim in func_attrs["input_accessors"][1].original_shapes
]
cdims = [
"&" + dim._attrs["name"]
for dim in func_attrs["output_accessors"][0].original_shapes
]
pdims = []
if func_attrs.get("shape") is not None:
pdims = list(func_attrs["shape"])
return FUNC_CALL_TEMPLATE.render(
func_name=func_attrs["name"],
in_ptr=a._attrs["name"],
weight_ptr=b._attrs["name"],
out_ptr=c._attrs["name"],
bias_ptr=bias_ptr,
d0_ptr=d0_ptr,
d1_ptr=d1_ptr,
adims=adims,
bdims=bdims,
cdims=cdims,
pdims=pdims,
indent=indent,
gemm_flag=gemm_flag,
)
def default_fproc_f16(*, op, a_layout, b_layout, c_layout):
"""Filter the input operation by layouts.
Parameters
----------
op: operation
aitemplate operation
a_layout: ck_lib.library.LayoutType
a layout type.
b_layout: ck_lib.library.LayoutType
b layout type.
c_layout: ck_lib.library.LayoutType
c layout type.
Returns
-------
List
List of filtered op (can be empty).
"""
import copy
import ck_lib
ret = []
data_type = ck_lib.library.DataType.f16
acc_type = ck_lib.library.DataType.f32
if (
op.A.element == data_type
and op.B.element == data_type
and op.C.element == data_type
and op.accumulator_type() == acc_type
and op.A.layout == a_layout
and op.B.layout == b_layout
and op.C.layout == c_layout
):
ret += [copy.deepcopy(op)]
return ret
def make_fproc_f16(func_attrs, layout, op_kind, extra_kind):
"""This function sets a callback for processing the epilogue of the kernel
associated with func_attrs.
Parameters
----------
func_attrs: Dictionary
kernel attributes dictionary
layout: layout object
kernel layout
op_kind : ck_lib.library.OperationKind
Operation kind.
extra_kind : ck_lib.library.[AnyKind]
Used as an extra flag to distinguish kernels.
E.g. bias_add_relu vs. add_relu_bias
"""
def fproc_f16(op):
a_layout, b_layout, c_layout = layout.ck_lib_layouts()
return default_fproc_f16(
op=op,
a_layout=a_layout,
b_layout=b_layout,
c_layout=c_layout,
)
func_attrs["op_instance"] = extract_config(op_kind, extra_kind, fproc_f16)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "logging.h"
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1,
256, // block_size
64, // m_per_block
128, // n_per_block
32, // k_per_block
8, // ak1
2, // bk1
32, // m_per_xdl
32, // n_per_xdl
1, // m_xdl_per_wave
2, // n_xdl_per_wave
ck::Sequence<4,64,1>, // thread_cluster_length
ck::Sequence<1,0,2>, // thread_cluster_arrange_order
ck::Sequence<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<8,32,1>, // thread_cluster_length
ck::Sequence<0,2,1>, // thread_cluster_arrange_order
ck::Sequence<0,2,1>, // src_access_order
1, // src_vector_dim
4, // src_scalar_per_vector
2, // dst_scalar_per_vector
0, // add_extra_dim
1, // m_xdl_per_wave
1, // n_xdl_per_wave
ck::Sequence<1,32,1,8>, // m_n_block_wave_per_xdl
8 // scalar_per_vector
>;
using fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0 = gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT;
void gemm_rrr_3(
void * in_ptr,
void * weight_ptr,
void * out_ptr,
int64_t* a_dim0,
int64_t* a_dim1,
int64_t* b_dim0,
int64_t* b_dim1,
int64_t* c_dim0,
int64_t* c_dim1,
hipStream_t stream
) {
ck::index_t M = (*a_dim0);
ck::index_t N = (*b_dim1);
ck::index_t K = (*a_dim1);
int64_t offset_a = 0;
int64_t offset_b = 0;
int64_t offset_c = 0;
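// row-major layouts: the leading dimension (stride between consecutive rows) of
// A[M, K] is K = *a_dim1, of B[K, N] is N = *b_dim1, and of C[M, N] is N = *c_dim1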
ck::index_t stride_a = *a_dim1;
ck::index_t stride_b = *b_dim1;
ck::index_t stride_c = *c_dim1;
if (M == 256 && N == 32 && K == 128) {
auto op = fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(
static_cast<ck::half_t *>(in_ptr) + offset_a,
static_cast<ck::half_t *>(weight_ptr) + offset_b,
static_cast<ck::half_t *>(out_ptr) + offset_c,
M,
N,
K,
stride_a,
stride_b,
stride_c,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}
);
if(!op.IsSupportedArgument(argument)) {
LOG(FATAL) << "wrong! " << op.GetTypeString() << " with the specified compilation parameters does not support this Gemm problem.";
}
invoker.Run(argument, StreamConfig{stream, false});
return;
}
LOG(FATAL) << "Unsupported workload for this gemm specialization.";
}
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "logging.h"
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1,
256, // block_size
64, // m_per_block
128, // n_per_block
32, // k_per_block
8, // ak1
2, // bk1
32, // m_per_xdl
32, // n_per_xdl
1, // m_xdl_per_wave
2, // n_xdl_per_wave
ck::Sequence<4,64,1>, // thread_cluster_length
ck::Sequence<1,0,2>, // thread_cluster_arrange_order
ck::Sequence<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
1, // add_extra_dim
ck::Sequence<8,32,1>, // thread_cluster_length
ck::Sequence<0,2,1>, // thread_cluster_arrange_order
ck::Sequence<0,2,1>, // src_access_order
1, // src_vector_dim
4, // src_scalar_per_vector
2, // dst_scalar_per_vector
0, // add_extra_dim
1, // m_xdl_per_wave
1, // n_xdl_per_wave
ck::Sequence<1,32,1,8>, // m_n_block_wave_per_xdl
8 // scalar_per_vector
>;
using fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0 = gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT;
void gemm_rrr_3(
void * in_ptr,
void * weight_ptr,
void * out_ptr,
int64_t* a_dim0,
int64_t* a_dim1,
int64_t* b_dim0,
int64_t* b_dim1,
int64_t* c_dim0,
int64_t* c_dim1,
hipStream_t stream
) {
ck::index_t M = (*a_dim0);
ck::index_t N = (*b_dim1);
ck::index_t K = (*a_dim1);
int64_t offset_a = 0;
int64_t offset_b = 0;
int64_t offset_c = 0;
ck::index_t stride_a = *a_dim1;
ck::index_t stride_b = *b_dim1;
ck::index_t stride_c = *c_dim1;
if (M == 256 && N == 32 && K == 128) {
auto op = fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(
static_cast<ck::half_t *>(in_ptr) + offset_a,
static_cast<ck::half_t *>(weight_ptr) + offset_b,
static_cast<ck::half_t *>(out_ptr) + offset_c,
M,
N,
K,
stride_a,
stride_b,
stride_c,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}
);
if(!op.IsSupportedArgument(argument)) {
LOG(FATAL) << "wrong! " << op.GetTypeString() << " with the specified compilation parameters does not support this Gemm problem.";
}
invoker.Run(argument, StreamConfig{stream, false});
return;
}
LOG(FATAL) << "Unsupported workload for this gemm specialization.";
}
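// --- Usage sketch (illustrative only; not emitted by the generator) ---
// gemm_rrr_3 expects row-major fp16 device buffers A[256x128], B[128x32] and
// C[256x32]; the dimension pointers supply M/N/K and the row strides. The
// buffer and function names below are assumptions for the sketch.
// void run_gemm_rrr_3_example(hipStream_t stream) {
//     ck::half_t *a, *b, *c;
//     hipMalloc(&a, 256 * 128 * sizeof(ck::half_t));
//     hipMalloc(&b, 128 * 32 * sizeof(ck::half_t));
//     hipMalloc(&c, 256 * 32 * sizeof(ck::half_t));
//     int64_t a0 = 256, a1 = 128, b0 = 128, b1 = 32, c0 = 256, c1 = 32;
//     gemm_rrr_3(a, b, c, &a0, &a1, &b0, &b1, &c0, &c1, stream);
//     hipStreamSynchronize(stream);
//     hipFree(a); hipFree(b); hipFree(c);
// }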
size_t GLOBAL_WORKSPACE_SIZE = 0;
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/reduction_operator.hpp"
#include "include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
using layernorm_rank2_256_1_256_1_8_1_8_1_8_1_8_8 = ck::tensor_operation::device::DeviceLayernormImpl<
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::half_t,
ck::tensor_operation::element_wise::PassThrough,
2,
1,
256, // block_size
1, // m_cluster_size
256, // k_cluster_size
1, // m_slice_size
8, // k_slice_size
1, // in_src_dim
8, // in_src_size
1, // gamma_src_dim
8, // gamma_src_size
1, // beta_src_dim
8, // beta_src_size
8 // out_dst_size
>;
using DeviceInstance = layernorm_rank2_256_1_256_1_8_1_8_1_8_1_8_8;
void layernorm_0(void* input,
void* gamma,
void* beta,
void* output,
int64_t* in_0,
int64_t* in_1,
hipStream_t stream)
{
int M = *in_0;
int N = *in_1;
std::vector<ck::index_t> i_inStrides;
i_inStrides.push_back(N);
i_inStrides.push_back(1);
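// Input and output use row-major strides {N, 1}; gamma and beta use strides
// {0, 1} so the same length-N vectors are broadcast across all M rows. The
// reduction runs over the trailing dimension ({1}) with epsilon 1e-05.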
auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer(
{M, N},
i_inStrides,
std::vector<ck::index_t>{0, 1},
std::vector<ck::index_t>{0, 1},
i_inStrides,
{1},
1e-05,
static_cast<ck::half_t *>(input),
static_cast<ck::half_t *>(gamma),
static_cast<ck::half_t *>(beta),
static_cast<ck::half_t *>(output),
ck::tensor_operation::element_wise::PassThrough{}
);
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
throw std::runtime_error(
"wrong! device_layernorm with the specified compilation parameters does "
"not support this Softmax problem");
};
std::string instance_name = device_instance.GetTypeString();
auto invoker_ptr = device_instance.MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false});
return;
}
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
// The rocRAND generator must be created before it is seeded and used below.
rocrand_create_generator(&generator, ROCRAND_RNG_PSEUDO_DEFAULT);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
ptrs.reserve(512);
}
~ProfilerMemoryPool() {
for(int i = 0; i < ptrs.size(); i++){
hipFree(ptrs[i]);
}
// Release the rocRAND generator created in the constructor.
rocrand_destroy_generator(generator);
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
size_t length = size * sizeof(DType);
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
// Generate only as many floats as fit into the allocation; generating `size`
// floats into a buffer of `size` DType elements would overrun it for fp16.
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), length / sizeof(float), mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
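// RequestHalfTensorByIdx cycles through the `copy` back-to-back copies of a
// logical tensor: each call returns the current copy and advances the offset,
// wrapping back to 0 after the last copy, so consecutive profiler iterations
// do not keep re-reading the same (possibly L2-resident) buffer.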
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
int main(int argc, char** argv) {
const int64_t in_0 = std::stoi(argv[1]);
const int64_t in_1 = std::stoi(argv[2]);
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
int64_t ptr_sz = in_0 * in_1;
int64_t norm_dim = in_1;
// TODO: special pool size for 8M L2 cache
// need to tune it for other devices
int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 23) / ptr_sz)));
memory_pool->AllocateHalfTensor(ptr_sz, mem_pool_sz); // in: index 0
memory_pool->AllocateHalfTensor(ptr_sz, mem_pool_sz); // out: index 1
memory_pool->AllocateHalfTensor(norm_dim, mem_pool_sz); // gamma: index 2
memory_pool->AllocateHalfTensor(norm_dim, mem_pool_sz); // beta: index 3
// warmup
for(int i = 0; i < 3; ++i) {
layernorm_0(
(void *) memory_pool->RequestHalfTensorByIdx(0),
(void *) memory_pool->RequestHalfTensorByIdx(2),
(void *) memory_pool->RequestHalfTensorByIdx(3),
(void *) memory_pool->RequestHalfTensorByIdx(1),
const_cast<int64_t *>(&in_0),
const_cast<int64_t *>(&in_1),
stream
);
}
// run
KernelTimerImpl timer;
timer.Start();
for(int i = 0; i < 5; ++i) {
layernorm_0(
(void *) memory_pool->RequestHalfTensorByIdx(0),
(void *) memory_pool->RequestHalfTensorByIdx(2),
(void *) memory_pool->RequestHalfTensorByIdx(3),
(void *) memory_pool->RequestHalfTensorByIdx(1),
const_cast<int64_t *>(&in_0),
const_cast<int64_t *>(&in_1),
stream
);
}
timer.End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
}
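// Example invocation (the executable name is an assumption; use whatever this
// profiler source is compiled into):
//   ./layernorm_profiler 256 256
// prints the workspace size ("WS:") followed by the elapsed time in ms
// ("TIME:") for the 5 timed layernorm_0 runs on a 256x256 fp16 input.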
#pragma once
#include "logging.h"
#include "device_functions-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include "macros.h"
#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <math.h>
void gemm_rrr_3(
void *,
void *,
void *,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
hipStream_t
);
#define CHECK_VECTOR_ACCESS(vector, idx) \
if (idx >= vector.size()) { \
throw std::out_of_range( \
"[__func__]: index out of range, " #vector ".size()=" + \
std::to_string(vector.size()) + ", got " + std::to_string(idx)); \
}
namespace ait {
namespace {
void DeviceCheckLastError(const char* file, int line) {
auto device_error = GetLastError();
if (device_error != GetDeviceSuccess()) {
std::string msg = std::string("Got error: ") + GetLastErrorString() +
" enum: " + std::to_string(device_error) +
" at " + file + ": " + std::to_string(line);
LOG(ERROR) << msg;
throw std::runtime_error(msg);
}
}
thread_local bool target_has_graph_mode = false;
} // namespace
// Model is the class that actually performs inference. It owns memory for
// intermediate tensors and dynamic dimensions. Constants are owned by
// the model's owning container object, and input/output memory is owned
// by the user.
// Once an inference run has started, it is not safe to re-use the Model
// until the run has finished!
class Model {
public:
Model(
size_t blob_size,
size_t workspace_size,
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
uint8_t* constants,
AITemplateAllocator& allocator)
: blob_(RAII_DeviceMalloc(blob_size, allocator)),
workspace_(RAII_DeviceMalloc(workspace_size, allocator)),
params_(num_inputs + num_outputs + num_unbound_constants),
num_inputs_(num_inputs),
num_outputs_(num_outputs),
constants_(constants) {
dmlc::InitLogging("aitemplate"); // TODO(xxx): render network name
LOG(INFO) << "Init AITemplate Runtime.";
global_workspace_ = static_cast<uint8_t*>(workspace_.get()) + 0;
unique_workspace_ = static_cast<uint8_t*>(workspace_.get());
DEVICE_CHECK(GetDevice(&device_idx_))
DEVICE_CHECK(CreateEvent(&run_finished_));
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
DEVICE_CHECK(cudaDeviceGetAttribute(
&max_smem_size_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx_));
#endif
DEVICE_CHECK(GetDeviceProperties(&device_properties_, device_idx_));
DEVICE_CHECK(StreamCreate(&graph_capture_stream_, /*non_blocking=*/true));
InitConstants(constants_);
auto* blob_ptr = static_cast<uint8_t*>(blob_.get());
}
~Model() {
if (run_finished_ != nullptr) {
DestroyEvent(run_finished_);
}
if (graph_capture_stream_ != nullptr) {
StreamDestroy(graph_capture_stream_);
}
if (graph_exec_ != nullptr) {
GraphExecDestroy(graph_exec_);
}
}
Model(Model&& other) {
run_finished_ = other.run_finished_;
graph_exec_ = other.graph_exec_;
graph_capture_stream_ = other.graph_capture_stream_;
other.run_finished_ = nullptr;
other.graph_exec_ = nullptr;
other.graph_capture_stream_ = nullptr;
constants_ = other.constants_;
num_inputs_ = other.num_inputs_;
global_workspace_ = other.global_workspace_;
unique_workspace_ = other.unique_workspace_;
workspace_ = std::move(other.workspace_);
params_ = std::move(other.params_);
constant_name_to_ptr_ = std::move(other.constant_name_to_ptr_);
// Re-wire the pointers in the above 2 structures.
InitConstants(constants_);
}
Model& operator=(Model&&) = delete;
Model(const Model&) = delete;
Model& operator=(const Model&) = delete;
void SetUpInputsOutputs() {
input_0 = static_cast<decltype(input_0)>(params_[0].ptr);
if (input_0 == nullptr) {
throw std::runtime_error("Constant input_0 was not set! Set the value with set_constant.");
}
input_1 = static_cast<decltype(input_1)>(params_[1].ptr);
if (input_1 == nullptr) {
throw std::runtime_error("Constant input_1 was not set! Set the value with set_constant.");
}
output_0 = static_cast<decltype(output_0)>(params_[2].ptr);
if (output_0 == nullptr) {
throw std::runtime_error("Constant output_0 was not set! Set the value with set_constant.");
}
}
void DeviceToDeviceCopies(StreamType stream) {
}
void Run(StreamType stream, bool graph_mode) {
SetUpInputsOutputs();
if (target_has_graph_mode && graph_mode) {
RunAsGraph(stream);
} else {
RunImpl(stream);
}
DEVICE_CHECK(EventRecord(run_finished_, stream));
}
void RunImpl(StreamType stream) {
gemm_rrr_3(
input_0,
input_1,
output_0,
&input_0_dim_0,
&input_0_dim_1,
&input_1_dim_0,
&input_1_dim_1,
&input_0_dim_0,
&input_1_dim_1,
stream
);
DeviceCheckLastError(__FILE__, __LINE__);
DeviceToDeviceCopies(stream);
}
bool IsPending() {
auto query = QueryEvent(run_finished_);
if (query == GetDeviceNotReady()) {
return true;
}
if (query != GetDeviceSuccess()) {
LOG(WARNING) << "Pending model run did not finish successfully. Error: "
<< GetErrorString(query);
}
return false;
}
void WaitForCompletion() {
DEVICE_CHECK(EventSynchronize(run_finished_));
}
size_t NumInputs() const {
return num_inputs_;
}
size_t NumOutputs() const {
return num_outputs_;
}
void SetParam(const void* src, size_t param_idx) {
CHECK_VECTOR_ACCESS(params_, param_idx)
// const_cast is not ideal here, but it is unfortunately
// necessary:
// 1) We store outputs and inputs in the same vector,
// and outputs cannot be const.
// 2) Most of the codegen is not const-correct (most ops
// require non-const pointers). So even if we put const
// pointers into params, a const_cast would be required
// somewhere else.
params_[param_idx].ptr = const_cast<void*>(src);
}
void SetInput(const void* src, const AITemplateParamShape& shape, size_t idx) {
SetInputShape(shape, idx);
SetParam(src, idx);
}
void SetOutput(void* src, size_t idx) {
SetParam(src, idx + num_inputs_);
}
// Write the (possibly dynamic) output shape to the given pointer.
// Note that this should be called _after_ the shape inference in
// Run() is finished. output_shape_out should be able to store
// at least GetOutputMaximumShape(idx).size values.
void GetOutputShape(size_t idx, int64_t* output_shape_out) {
const auto param_idx = idx + num_inputs_;
CHECK_VECTOR_ACCESS(params_, param_idx);
const auto& shape_ptrs = params_[param_idx].shape_ptrs;
for (size_t i = 0; i < shape_ptrs.size(); ++i) {
output_shape_out[i] = shape_ptrs[i].GetValue();
}
}
void SetConstant(const char* name, const void* src) {
auto it = constant_name_to_ptr_.find(name);
if (it == constant_name_to_ptr_.end()) {
throw std::out_of_range(std::string("Could not find constant ") + name);
}
const void** ptr = it->second;
*ptr = src;
}
private:
void InitConstants(uint8_t* constants) {
params_[0].shape_ptrs = {ParamDim(256, 256, &input_0_dim_0), ParamDim(128, 128, &input_0_dim_1)};
params_[1].shape_ptrs = {ParamDim(128, 128, &input_1_dim_0), ParamDim(32, 32, &input_1_dim_1)};
params_[2].shape_ptrs = {ParamDim(256, 256, &input_0_dim_0), ParamDim(32, 32, &input_1_dim_1)};
}
void SetInputShape(const AITemplateParamShape& shape, size_t idx) {
auto& param = params_[idx];
if (shape.size != param.shape_ptrs.size()) {
throw std::runtime_error(
"[SetInputShape] Got wrong param shape for input " + std::to_string(idx) +
"; expected " + std::to_string(param.shape_ptrs.size()) + ", got " +
std::to_string(shape.size));
}
for (size_t i = 0; i < param.shape_ptrs.size(); ++i) {
param.shape_ptrs[i].SetValue(shape.shape_data[i]);
}
}
DeviceError EndCapture(GraphType* graph_ptr) {
auto err = StreamEndCapture(graph_capture_stream_, graph_ptr);
if (err != GetDeviceSuccess()) {
// If we can't take the stream out of capture mode, something is probably
// wrong with CUDA graph for this model (e.g. there might have been an
// illegal capture mode operation). Disable graph mode to avoid such issues
// in future iterations.
target_has_graph_mode = false;
LOG(WARNING) << "Graph capture failed to end. Disabling graph mode.";
return err;
}
return GetDeviceSuccess();
}
void RunAsGraph(StreamType stream) {
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
RunImpl(graph_capture_stream_);
} catch (...) {
GraphType graph;
// No need to DEVICE_CHECK here, we want to see the original exception.
EndCapture(&graph);
if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
LOG(WARNING) << "Graph destruction failed while handling exception! Memory will be leaked.";
}
throw;
}
// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans it up when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr){ return EndCapture(graph_ptr); }
);
if (graph_exec_ == nullptr) {
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
} else if (GraphExecUpdate(graph_exec_, graph.get()) != GetDeviceSuccess()) {
// Consume the last cuda error, which may affect the next GraphExecLaunch
// call.
GetLastError();
DEVICE_CHECK(GraphExecDestroy(graph_exec_));
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
}
int device_idx_;
int max_smem_size_{0};
DevicePropertyType device_properties_;
// This event tracks when the inference is finished
// so that this Model may be reclaimed by its owning
// ModelContainer.
EventType run_finished_;
// A blob of memory used for storing intermediate tensors.
GPUPtr blob_;
// Memory for constants that were folded into the *.so. Unowned by Model,
// owned by ModelContainer.
// TODO: make this const. It can't be const right now because we derive
// tensor pointers from it, and no tensor pointers are const.
uint8_t* constants_;
size_t num_inputs_;
size_t num_outputs_;
// The workspace blob is used as scratch memory. See
// _generate_workspace in memory planning for more information.
GPUPtr workspace_;
uint8_t* global_workspace_{nullptr};
uint8_t* unique_workspace_{nullptr};
class ParamDim {
public:
ParamDim(int64_t lower_bound, int64_t upper_bound, int64_t* value) :
lower_bound_(lower_bound),
upper_bound_(upper_bound),
value_(value) {}
void SetValue(int64_t new_value) {
if (new_value < lower_bound_ || new_value > upper_bound_) {
throw std::out_of_range(
"[SetValue] Dimension got value out of bounds; expected value to be in [" +
std::to_string(lower_bound_) + ", " + std::to_string(upper_bound_) + "], but got " +
std::to_string(new_value)
);
}
*value_ = new_value;
}
int64_t GetValue() const {
return *value_;
}
private:
int64_t lower_bound_;
int64_t upper_bound_;
int64_t* value_;
};
struct ParamInfo {
void* ptr = nullptr;
// TODO add offset
const char* name;
std::vector<ParamDim> shape_ptrs;
};
// Contains info for all tensors marked as inputs
// or outputs. The first num_inputs elements are the inputs.
// Constants are not included.
std::vector<ParamInfo> params_;
GraphExecType graph_exec_ = nullptr;
StreamType graph_capture_stream_;
std::unordered_map<std::string, const void**> constant_name_to_ptr_;
void * input_0 {nullptr};
void * input_1 {nullptr};
void * output_0 {nullptr};
int64_t input_0_dim_0 { 256 };
int64_t input_0_dim_1 { 128 };
int64_t input_1_dim_0 { 128 };
int64_t input_1_dim_1 { 32 };
};
} // namespace ait
#pragma once
#include "logging.h"
#include "device_functions-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include "macros.h"
#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <math.h>
void gemm_rrr_3(
void *,
void *,
void *,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
int64_t*,
hipStream_t
);
#define CHECK_VECTOR_ACCESS(vector, idx) \
if (idx >= vector.size()) { \
throw std::out_of_range( \
"[__func__]: index out of range, " #vector ".size()=" + \
std::to_string(vector.size()) + ", got " + std::to_string(idx)); \
}
namespace ait {
namespace {
void DeviceCheckLastError(const char* file, int line) {
auto device_error = GetLastError();
if (device_error != GetDeviceSuccess()) {
std::string msg = std::string("Got error: ") + GetLastErrorString() +
" enum: " + std::to_string(device_error) +
" at " + file + ": " + std::to_string(line);
LOG(ERROR) << msg;
throw std::runtime_error(msg);
}
}
thread_local bool target_has_graph_mode = false;
} // namespace
// Model is the class that actually performs inference. It owns memory for
// intermediate tensors and dynamic dimensions. Constants are owned by
// the model's owning container object, and input/output memory is owned
// by the user.
// Once an inference run has started, it is not safe to re-use the Model
// until the run has finished!
class Model {
public:
Model(
size_t blob_size,
size_t workspace_size,
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
uint8_t* constants,
AITemplateAllocator& allocator)
: blob_(RAII_DeviceMalloc(blob_size, allocator)),
workspace_(RAII_DeviceMalloc(workspace_size, allocator)),
params_(num_inputs + num_outputs + num_unbound_constants),
num_inputs_(num_inputs),
num_outputs_(num_outputs),
constants_(constants) {
dmlc::InitLogging("aitemplate"); // TODO(xxx): render network name
LOG(INFO) << "Init AITemplate Runtime.";
global_workspace_ = static_cast<uint8_t*>(workspace_.get()) + 0;
unique_workspace_ = static_cast<uint8_t*>(workspace_.get());
DEVICE_CHECK(GetDevice(&device_idx_))
DEVICE_CHECK(CreateEvent(&run_finished_));
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
DEVICE_CHECK(cudaDeviceGetAttribute(
&max_smem_size_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx_));
#endif
DEVICE_CHECK(GetDeviceProperties(&device_properties_, device_idx_));
DEVICE_CHECK(StreamCreate(&graph_capture_stream_, /*non_blocking=*/true));
InitConstants(constants_);
auto* blob_ptr = static_cast<uint8_t*>(blob_.get());
}
~Model() {
if (run_finished_ != nullptr) {
DestroyEvent(run_finished_);
}
if (graph_capture_stream_ != nullptr) {
StreamDestroy(graph_capture_stream_);
}
if (graph_exec_ != nullptr) {
GraphExecDestroy(graph_exec_);
}
}
Model(Model&& other) {
run_finished_ = other.run_finished_;
graph_exec_ = other.graph_exec_;
graph_capture_stream_ = other.graph_capture_stream_;
other.run_finished_ = nullptr;
other.graph_exec_ = nullptr;
other.graph_capture_stream_ = nullptr;
constants_ = other.constants_;
num_inputs_ = other.num_inputs_;
global_workspace_ = other.global_workspace_;
unique_workspace_ = other.unique_workspace_;
workspace_ = std::move(other.workspace_);
params_ = std::move(other.params_);
constant_name_to_ptr_ = std::move(other.constant_name_to_ptr_);
// Re-wire the pointers in the above 2 structures.
InitConstants(constants_);
}
Model& operator=(Model&&) = delete;
Model(const Model&) = delete;
Model& operator=(const Model&) = delete;
void SetUpInputsOutputs() {
input_0 = static_cast<decltype(input_0)>(params_[0].ptr);
if (input_0 == nullptr) {
throw std::runtime_error("Constant input_0 was not set! Set the value with set_constant.");
}
input_1 = static_cast<decltype(input_1)>(params_[1].ptr);
if (input_1 == nullptr) {
throw std::runtime_error("Constant input_1 was not set! Set the value with set_constant.");
}
output_0 = static_cast<decltype(output_0)>(params_[2].ptr);
if (output_0 == nullptr) {
throw std::runtime_error("Constant output_0 was not set! Set the value with set_constant.");
}
}
void DeviceToDeviceCopies(StreamType stream) {
}
void Run(StreamType stream, bool graph_mode) {
SetUpInputsOutputs();
if (target_has_graph_mode && graph_mode) {
RunAsGraph(stream);
} else {
RunImpl(stream);
}
DEVICE_CHECK(EventRecord(run_finished_, stream));
}
void RunImpl(StreamType stream) {
gemm_rrr_3(
input_0,
input_1,
output_0,
&input_0_dim_0,
&input_0_dim_1,
&input_1_dim_0,
&input_1_dim_1,
&input_0_dim_0,
&input_1_dim_1,
stream
);
DeviceCheckLastError(__FILE__, __LINE__);
DeviceToDeviceCopies(stream);
}
bool IsPending() {
auto query = QueryEvent(run_finished_);
if (query == GetDeviceNotReady()) {
return true;
}
if (query != GetDeviceSuccess()) {
LOG(WARNING) << "Pending model run did not finish successfully. Error: "
<< GetErrorString(query);
}
return false;
}
void WaitForCompletion() {
DEVICE_CHECK(EventSynchronize(run_finished_));
}
size_t NumInputs() const {
return num_inputs_;
}
size_t NumOutputs() const {
return num_outputs_;
}
void SetParam(const void* src, size_t param_idx) {
CHECK_VECTOR_ACCESS(params_, param_idx)
// const_cast is not ideal here, but it is unfortunately
// necessary:
// 1) We store outputs and inputs in the same vector,
// and outputs cannot be const.
// 2) Most of the codegen is not const-correct (most ops
// require non-const pointers). So even if we put const
// pointers into params, a const_cast would be required
// somewhere else.
params_[param_idx].ptr = const_cast<void*>(src);
}
void SetInput(const void* src, const AITemplateParamShape& shape, size_t idx) {
SetInputShape(shape, idx);
SetParam(src, idx);
}
void SetOutput(void* src, size_t idx) {
SetParam(src, idx + num_inputs_);
}
// Write the (possibly dynamic) output shape to the given pointer.
// Note that this should be called _after_ the shape inference in
// Run() is finished. output_shape_out should be able to store
// at least GetOutputMaximumShape(idx).size values.
void GetOutputShape(size_t idx, int64_t* output_shape_out) {
const auto param_idx = idx + num_inputs_;
CHECK_VECTOR_ACCESS(params_, param_idx);
const auto& shape_ptrs = params_[param_idx].shape_ptrs;
for (size_t i = 0; i < shape_ptrs.size(); ++i) {
output_shape_out[i] = shape_ptrs[i].GetValue();
}
}
void SetConstant(const char* name, const void* src) {
auto it = constant_name_to_ptr_.find(name);
if (it == constant_name_to_ptr_.end()) {
throw std::out_of_range(std::string("Could not find constant ") + name);
}
const void** ptr = it->second;
*ptr = src;
}
private:
void InitConstants(uint8_t* constants) {
params_[0].shape_ptrs = {ParamDim(256, 256, &input_0_dim_0), ParamDim(128, 128, &input_0_dim_1)};
params_[1].shape_ptrs = {ParamDim(128, 128, &input_1_dim_0), ParamDim(32, 32, &input_1_dim_1)};
params_[2].shape_ptrs = {ParamDim(256, 256, &input_0_dim_0), ParamDim(32, 32, &input_1_dim_1)};
}
void SetInputShape(const AITemplateParamShape& shape, size_t idx) {
auto& param = params_[idx];
if (shape.size != param.shape_ptrs.size()) {
throw std::runtime_error(
"[SetInputShape] Got wrong param shape for input " + std::to_string(idx) +
"; expected " + std::to_string(param.shape_ptrs.size()) + ", got " +
std::to_string(shape.size));
}
for (size_t i = 0; i < param.shape_ptrs.size(); ++i) {
param.shape_ptrs[i].SetValue(shape.shape_data[i]);
}
}
DeviceError EndCapture(GraphType* graph_ptr) {
auto err = StreamEndCapture(graph_capture_stream_, graph_ptr);
if (err != GetDeviceSuccess()) {
// If we can't take the stream out of capture mode, something is probably
// wrong with CUDA graph for this model (e.g. there might have been an
// illegal capture mode operation). Disable graph mode to avoid such issues
// in future iterations.
target_has_graph_mode = false;
LOG(WARNING) << "Graph capture failed to end. Disabling graph mode.";
return err;
}
return GetDeviceSuccess();
}
void RunAsGraph(StreamType stream) {
DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
try {
RunImpl(graph_capture_stream_);
} catch (...) {
GraphType graph;
// No need to DEVICE_CHECK here, we want to see the original exception.
EndCapture(&graph);
if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
LOG(WARNING) << "Graph destruction failed while handling exception! Memory will be leaked.";
}
throw;
}
// The following function ends the capture and creates a graph
// inside a unique_ptr that cleans it up when it goes out of scope.
// Note that it throws an exception if EndCapture fails.
auto graph = RAII_EndCaptureAndCreateGraph(
[this](GraphType* graph_ptr){ return EndCapture(graph_ptr); }
);
if (graph_exec_ == nullptr) {
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
} else if (GraphExecUpdate(graph_exec_, graph.get()) != GetDeviceSuccess()) {
// Consume the last cuda error, which may affect the next GraphExecLaunch
// call.
GetLastError();
DEVICE_CHECK(GraphExecDestroy(graph_exec_));
DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
}
DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
}
int device_idx_;
int max_smem_size_{0};
DevicePropertyType device_properties_;
// This event tracks when the inference is finished
// so that this Model may be reclaimed by its owning
// ModelContainer.
EventType run_finished_;
// A blob of memory used for storing intermediate tensors.
GPUPtr blob_;
// Memory for constants that were folded into the *.so. Unowned by Model,
// owned by ModelContainer.
// TODO: make this const. It can't be const right now because we derive
// tensor pointers from it, and no tensor pointers are const.
uint8_t* constants_;
size_t num_inputs_;
size_t num_outputs_;
// The workspace blob is used as scratch memory. See
// _generate_workspace in memory planning for more information.
GPUPtr workspace_;
uint8_t* global_workspace_{nullptr};
uint8_t* unique_workspace_{nullptr};
class ParamDim {
public:
ParamDim(int64_t lower_bound, int64_t upper_bound, int64_t* value) :
lower_bound_(lower_bound),
upper_bound_(upper_bound),
value_(value) {}
void SetValue(int64_t new_value) {
if (new_value < lower_bound_ || new_value > upper_bound_) {
throw std::out_of_range(
"[SetValue] Dimension got value out of bounds; expected value to be in [" +
std::to_string(lower_bound_) + ", " + std::to_string(upper_bound_) + "], but got " +
std::to_string(new_value)
);
}
*value_ = new_value;
}
int64_t GetValue() const {
return *value_;
}
private:
int64_t lower_bound_;
int64_t upper_bound_;
int64_t* value_;
};
struct ParamInfo {
void* ptr = nullptr;
// TODO add offset
const char* name;
std::vector<ParamDim> shape_ptrs;
};
// Contains info for all tensors marked as inputs
// or outputs. The first num_inputs elements are the inputs.
// Constants are not included.
std::vector<ParamInfo> params_;
GraphExecType graph_exec_ = nullptr;
StreamType graph_capture_stream_;
std::unordered_map<std::string, const void**> constant_name_to_ptr_;
void * input_0 {nullptr};
void * input_1 {nullptr};
void * output_0 {nullptr};
int64_t input_0_dim_0 { 256 };
int64_t input_0_dim_1 { 128 };
int64_t input_1_dim_0 { 128 };
int64_t input_1_dim_1 { 32 };
};
} // namespace ait
#include "model_container.h"
#include "device_functions-generated.h"
#include "raii_wrapper.h"
namespace ait {
ModelContainer::ModelContainer(
size_t num_models,
size_t blob_size,
size_t workspace_size,
size_t num_inputs,
size_t num_outputs,
size_t num_unbound_constants,
size_t params_size,
AITemplateAllocator& allocator)
: ModelContainerBase(
num_inputs,
num_outputs,
num_unbound_constants,
params_size,
allocator),
allocator_(allocator),
num_inputs_(num_inputs),
num_outputs_(num_outputs) {
if (num_models == 0) {
throw std::runtime_error("Number of models must be positive");
}
models_.reserve(num_models);
available_models_.reserve(num_models);
for (size_t i = 0; i < num_models; ++i) {
models_.emplace_back(
blob_size,
workspace_size,
num_inputs,
num_outputs,
num_unbound_constants,
static_cast<uint8_t*>(constants_.get()),
allocator);
available_models_.push_back(&models_.back());
}
}
void ModelContainer::Run(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool sync,
bool graph_mode,
int64_t** output_shapes_out) {
auto* model = GetAvailableModel();
try {
PrepareForRun(model, inputs, num_inputs, outputs, num_outputs);
model->Run(stream, graph_mode);
} catch (...) {
std::lock_guard lk(models_mutex_);
available_models_.push_back(model);
throw;
}
if (output_shapes_out) {
for (size_t i = 0; i < num_outputs; ++i) {
auto* out_shape = output_shapes_out[i];
model->GetOutputShape(i, out_shape);
}
}
{
std::lock_guard lk(models_mutex_);
pending_models_.push_back(model);
}
pending_models_available_.notify_one();
if (sync) {
StreamSynchronize(stream);
}
}
void ModelContainer::RunWithOutputsOnHost(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
int64_t** output_shapes_out) {
std::vector<std::pair<GPUPtr, size_t>> owned_outputs_ptrs;
std::vector<AITData> owned_outputs;
owned_outputs_ptrs.reserve(num_outputs);
owned_outputs.reserve(num_outputs);
for (size_t i = 0; i < num_outputs; ++i) {
size_t num_bytes = MaxOutputStorageBytes(i);
owned_outputs_ptrs.emplace_back(
RAII_DeviceMalloc(num_bytes, allocator_), num_bytes);
owned_outputs.emplace_back(
owned_outputs_ptrs.back().first.get(),
outputs[i].shape,
outputs[i].dtype);
}
Run(inputs,
num_inputs,
owned_outputs.data(),
num_outputs,
stream,
/*sync=*/false,
graph_mode,
output_shapes_out);
for (size_t i = 0; i < num_outputs; ++i) {
auto& owned_output = owned_outputs_ptrs[i];
auto& ptr = owned_output.first;
auto num_bytes = owned_output.second;
DEVICE_CHECK(CopyToHost(outputs[i].ptr, ptr.get(), num_bytes, stream));
}
DEVICE_CHECK(StreamSynchronize(stream));
}
float ModelContainer::Benchmark(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
size_t count,
size_t num_threads,
bool use_unique_stream_per_thread,
int64_t** output_shapes_out) {
if (num_threads == 0) {
num_threads = std::thread::hardware_concurrency();
}
if (num_threads == 1) {
return BenchmarkImpl(
inputs,
num_inputs,
outputs,
num_outputs,
stream,
graph_mode,
count,
output_shapes_out) /
count;
}
// Clone the outputs, each thread needs its own set
std::vector<std::vector<GPUPtr>> per_thread_outputs_ptrs;
std::vector<std::vector<AITData>> per_thread_outputs;
std::vector<StreamPtr> per_thread_streams;
per_thread_outputs_ptrs.reserve(num_threads - 1);
per_thread_outputs.reserve(num_threads - 1);
if (use_unique_stream_per_thread) {
per_thread_streams.reserve(num_threads);
for (size_t i = 0; i < num_threads; ++i) {
per_thread_streams.push_back(RAII_StreamCreate(/*non_blocking=*/true));
}
}
for (size_t i = 1; i < num_threads; ++i) {
std::vector<GPUPtr> cloned_outputs_ptrs;
std::vector<AITData> cloned_outputs;
cloned_outputs_ptrs.reserve(num_outputs);
cloned_outputs.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t num_bytes = MaxOutputStorageBytes(j);
cloned_outputs_ptrs.emplace_back(
RAII_DeviceMalloc(num_bytes, allocator_));
auto* new_pointer = cloned_outputs_ptrs.back().get();
DEVICE_CHECK(
DeviceToDeviceCopy(new_pointer, outputs[j].ptr, num_bytes, stream));
cloned_outputs.emplace_back(
new_pointer, outputs[j].shape, outputs[j].dtype);
}
per_thread_outputs_ptrs.push_back(std::move(cloned_outputs_ptrs));
per_thread_outputs.push_back(std::move(cloned_outputs));
}
DEVICE_CHECK(StreamSynchronize(stream));
auto get_stream = [stream, use_unique_stream_per_thread, &per_thread_streams](
size_t thread_idx) {
if (!use_unique_stream_per_thread) {
return stream;
}
return per_thread_streams[thread_idx].get();
};
auto thread_func = [&](size_t thread_idx) {
AITData* thread_outputs =
thread_idx == 0 ? outputs : per_thread_outputs[thread_idx - 1].data();
StreamType thread_stream = get_stream(thread_idx);
auto* thread_output_shapes_out =
thread_idx == 0 ? output_shapes_out : nullptr;
return BenchmarkImpl(
inputs,
num_inputs,
thread_outputs,
num_outputs,
thread_stream,
graph_mode,
count,
thread_output_shapes_out);
};
std::vector<std::future<float>> futures;
futures.reserve(num_threads);
for (size_t i = 0; i < num_threads; ++i) {
futures.push_back(std::async(std::launch::async, thread_func, i));
}
auto max_time = std::accumulate(
futures.begin(), futures.end(), 0.f, [](float cur_val, auto& future) {
return std::max(future.get(), cur_val);
});
// Verify that all the outputs are the same
for (size_t i = 0; i < num_outputs; ++i) {
auto output_size = MaxOutputStorageBytes(i);
auto output_host = std::make_unique<uint8_t[]>(output_size);
// NB: technically, we don't have to copy to host here, but using
// std::memcmp is easier than writing a kernel that does comparisons
// for both backends, and performance is not important here.
DEVICE_CHECK(
CopyToHost(output_host.get(), outputs[i].ptr, output_size, stream));
DEVICE_CHECK(StreamSynchronize(stream));
for (size_t thread_idx = 1; thread_idx < num_threads; ++thread_idx) {
auto* thread_output = per_thread_outputs[thread_idx - 1][i].ptr;
auto thread_output_host = std::make_unique<uint8_t[]>(output_size);
auto thread_stream = get_stream(thread_idx);
DEVICE_CHECK(CopyToHost(
thread_output_host.get(), thread_output, output_size, thread_stream));
DEVICE_CHECK(StreamSynchronize(thread_stream));
if (std::memcmp(
output_host.get(), thread_output_host.get(), output_size)) {
throw std::runtime_error(
"Output " + std::to_string(i) +
" did not match for a spawned thread!");
}
}
}
auto total_num_iters = num_threads * count;
return max_time / total_num_iters;
}
void ModelContainer::SetConstant(const char* name, const AITData& tensor) {
auto it = unbound_constant_name_to_idx_.find(name);
if (it == unbound_constant_name_to_idx_.end()) {
// TODO make this an exception after we fix the CMF benchmarks
LOG(ERROR) << "Constant " << name << " not found";
return;
}
auto constant_idx = it->second + num_inputs_ + num_outputs_;
ValidateDtype(tensor.dtype, constant_idx);
CHECK_VECTOR_ACCESS(max_param_storage_bytes_, constant_idx)
auto expected_num_bytes = max_param_storage_bytes_[constant_idx];
auto actual_num_bytes =
tensor.shape.Numel() * AITemplateDtypeSizeBytes(tensor.dtype);
if (expected_num_bytes != actual_num_bytes) {
throw std::runtime_error(
std::string(
"SetConstant did not recieve correct number of bytes for constant ") +
name + ": expected " + std::to_string(expected_num_bytes) +
" but got " + std::to_string(actual_num_bytes) +
". Check that the provided tensor's shape is correct.");
}
auto* src = tensor.ptr;
for (auto& model : models_) {
model.SetConstant(name, src);
}
}
size_t ModelContainer::NumInputs() const {
return num_inputs_;
}
const char* ModelContainer::InputName(size_t input_idx) const {
CHECK_VECTOR_ACCESS(param_names_, input_idx)
return param_names_[input_idx];
}
size_t ModelContainer::NumOutputs() const {
return num_outputs_;
}
const char* ModelContainer::OutputName(size_t output_idx) const {
auto idx = output_idx + num_inputs_;
CHECK_VECTOR_ACCESS(param_names_, idx)
return param_names_[idx];
}
AITemplateParamShape ModelContainer::MaxOutputShape(size_t output_idx) const {
auto idx = output_idx + num_inputs_;
CHECK_VECTOR_ACCESS(max_param_shapes_, idx)
auto& out_shape = max_param_shapes_[idx];
return AITemplateParamShape{out_shape.data(), out_shape.size()};
}
AITemplateDtype ModelContainer::OutputDtype(size_t output_idx) const {
auto idx = output_idx + num_inputs_;
CHECK_VECTOR_ACCESS(param_dtypes_, idx)
return param_dtypes_[idx];
}
size_t ModelContainer::MaxOutputStorageBytes(size_t output_idx) const {
auto idx = output_idx + num_inputs_;
CHECK_VECTOR_ACCESS(max_param_storage_bytes_, idx)
return max_param_storage_bytes_[idx];
}
void ModelContainer::PrepareForRun(
Model* model,
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs) {
if (num_inputs != num_inputs_) {
auto msg = "Got wrong number of inputs; expected " +
std::to_string(num_inputs_) + ", got " + std::to_string(num_inputs);
throw std::runtime_error(std::move(msg));
}
if (num_inputs > 0 && inputs == nullptr) {
throw std::runtime_error("inputs cannot be null");
}
if (num_outputs != num_outputs_) {
auto msg = "Got wrong number of outputs; expected " +
std::to_string(num_outputs_) + ", got " + std::to_string(num_outputs);
throw std::runtime_error(std::move(msg));
}
if (num_outputs > 0 && outputs == nullptr) {
throw std::runtime_error("outputs cannot be null");
}
for (size_t i = 0; i < num_inputs_; ++i) {
auto& input = inputs[i];
ValidateDtype(input.dtype, i);
model->SetInput(input.ptr, input.shape, i);
}
for (size_t i = 0; i < num_outputs_; ++i) {
auto& output = outputs[i];
ValidateDtype(output.dtype, i + num_inputs_);
model->SetOutput(output.ptr, i);
}
}
Model* ModelContainer::GetAvailableModel() {
std::unique_lock lk(models_mutex_);
if (available_models_.empty()) {
ReclaimFinishedModels(lk);
}
auto* result = available_models_.back();
available_models_.pop_back();
return result;
}
void ModelContainer::ReclaimFinishedModels(std::unique_lock<std::mutex>& lk) {
// Put any complete models at the end
auto it = std::stable_partition(
pending_models_.begin(), pending_models_.end(), [](Model* m) {
return m->IsPending();
});
if (it != pending_models_.end()) {
// Move all available models to the pool.
available_models_.insert(
available_models_.end(), it, pending_models_.end());
pending_models_.erase(it, pending_models_.end());
return;
}
pending_models_available_.wait(
lk, [this]() { return !pending_models_.empty(); });
// There are no available workspaces! We have to wait on one.
auto* model = pending_models_.front();
pending_models_.pop_front();
lk.unlock();
try {
model->WaitForCompletion();
} catch (...) {
lk.lock();
available_models_.push_back(model);
throw;
}
lk.lock();
available_models_.push_back(model);
}
void ModelContainer::ValidateDtype(AITemplateDtype dtype, size_t idx) const {
CHECK_VECTOR_ACCESS(param_dtypes_, idx)
if (dtype != param_dtypes_[idx]) {
auto GetEnumString = [](auto dtype) {
switch (dtype) {
case AITemplateDtype::kUnset:
return "kUnset";
case AITemplateDtype::kHalf:
return "kHalf";
case AITemplateDtype::kFloat:
return "kFloat";
case AITemplateDtype::kInt:
return "kInt";
case AITemplateDtype::kLong:
return "kLong";
default:
return "unknown";
}
};
throw std::runtime_error(
"Got wrong dtype for param " + std::to_string(idx) + "; expected " +
GetEnumString(param_dtypes_[idx]) + ", got " + GetEnumString(dtype));
}
}
float ModelContainer::BenchmarkImpl(
const AITData* inputs,
size_t num_inputs,
AITData* outputs,
size_t num_outputs,
StreamType stream,
bool graph_mode,
size_t count,
int64_t** output_shapes_out) {
auto* model = GetAvailableModel();
float runtime_ms = 0.;
auto start_event = RAII_CreateEvent();
auto end_event = RAII_CreateEvent();
try {
PrepareForRun(model, inputs, num_inputs, outputs, num_outputs);
DEVICE_CHECK(EventRecord(start_event.get(), stream));
for (size_t i = 0; i < count; ++i) {
model->Run(stream, graph_mode);
}
} catch (...) {
std::lock_guard lk(models_mutex_);
available_models_.push_back(model);
throw;
}
if (output_shapes_out) {
for (size_t i = 0; i < num_outputs; ++i) {
auto* out_shape = output_shapes_out[i];
model->GetOutputShape(i, out_shape);
}
}
// Push the model back into the pool before synchronizing the event
// to exercise the concurrency code
{
std::lock_guard lk(models_mutex_);
pending_models_.push_back(model);
}
pending_models_available_.notify_one();
DEVICE_CHECK(EventRecord(end_event.get(), stream));
DEVICE_CHECK(EventSynchronize(end_event.get()));
DEVICE_CHECK(
EventElapsedTime(&runtime_ms, start_event.get(), end_event.get()));
LOG(INFO) << "Benchmark runtime ms/iter: " << runtime_ms / count;
return runtime_ms;
}
} // namespace ait
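// --- Illustrative caller sketch (not part of the generated sources) ---
// ModelContainer::Run consumes AITData tensors whose shapes are described by
// AITemplateParamShape. `container`, `a_dev`, `b_dev`, `c_dev` and `stream`
// are assumptions: an already constructed ModelContainer* and fp16 device
// buffers matching the 256x128 / 128x32 / 256x32 GEMM above.
// std::vector<int64_t> a_shape{256, 128}, b_shape{128, 32}, c_shape{256, 32};
// AITData inputs[] = {
//     AITData(a_dev, AITemplateParamShape{a_shape.data(), a_shape.size()}, AITemplateDtype::kHalf),
//     AITData(b_dev, AITemplateParamShape{b_shape.data(), b_shape.size()}, AITemplateDtype::kHalf)};
// AITData outputs[] = {
//     AITData(c_dev, AITemplateParamShape{c_shape.data(), c_shape.size()}, AITemplateDtype::kHalf)};
// container->Run(inputs, /*num_inputs=*/2, outputs, /*num_outputs=*/1, stream,
//                /*sync=*/true, /*graph_mode=*/false, /*output_shapes_out=*/nullptr);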