gaoqiong / composable_kernel

Commit 516bbdcb, authored Apr 03, 2023 by Astha Rai

initial push with templated dev op

Parent: fde6d274
Showing 20 of 24 changed files, with 6949 additions and 0 deletions (+6949, -0).
python/AIT implementation/generation/Makefile                            +8    -0
python/AIT implementation/generation/Makefile2                           +16   -0
python/AIT implementation/generation/demo.py                             +19   -0
python/AIT implementation/generation/gemm_dev_op.py                      +665  -0
python/AIT implementation/generation/gemm_kernel.py                      +175  -0
python/AIT implementation/generation/norm_ex.py                          +282  -0
python/AIT implementation/generation/permute_ex.py                       +43   -0
python/AIT implementation/generation/xx.cpp                              +588  -0
python/AIT implementation/sample files/ck/device_gemm_dl_ck.hpp          +595  -0
python/AIT implementation/sample files/ck/gridwise_gemm_dl_v1r3_ck.hpp   +577  -0
python/AIT implementation/sample files/elementwise_common.py             +855  -0
python/AIT implementation/sample files/gemm_common.py                    +969  -0
python/AIT implementation/sample files/gemm_rrr.cpp                      +165  -0
python/AIT implementation/sample files/gemm_rrr_3.cpp                    +165  -0
python/AIT implementation/sample files/layernorm.cpp                     +277  -0
python/AIT implementation/sample files/model-generate.h                  +419  -0
python/AIT implementation/sample files/model-generated.h                 +419  -0
python/AIT implementation/sample files/model_container.cpp               +465  -0
python/AIT implementation/sample files/model_container.h                 +187  -0
python/AIT implementation/sample files/model_container_base.cpp          +60   -0
python/AIT implementation/generation/Makefile (new file, mode 100644)

gemm: xx.o

CFLAGS = -I ~/rocm/composable_kernel/include \
         -I /opt/rocm-5.1.1/hip/include \
         -I ~/rocm/composable_kernel/include/ \
         -I ~/rocm/composable_kernel/include/ck/ \
         -I ~/rocm/composable_kernel/include/ck/problem_transform/ \
         -I ~/rocm/composable_kernel/include/ck/tensor/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_description/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_operation/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/block/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/device/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/device/impl/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/element/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/grid/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/thread/ \
         -I ~/rocm/composable_kernel/include/ck/tensor_operation/gpu/warp/ \
         -I ~/rocm/composable_kernel/include/ck/host_utility \
         -I /external/include/half/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/host/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/host_tensor/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/obselete_driver_offline/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/reference_tensor_operation/cpu/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/reference_tensor_operation/gpu/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/tensor_operation_instance/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/tensor_operation_instance/gpu/reduce/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/tensor_op/ \
         -I ~/rocm/composable_kernel/library/include/ck/library/utility/ \
         -I ~/rocm/composable_kernel/profiler/include/

CXXFLAGS = -std=c++17

xx.o:
	hipcc -c $(CFLAGS) $(CXXFLAGS) xx.cpp
python/AIT implementation/generation/Makefile2 (new file, mode 100644)
CC = {{cc}}
CFLAGS = {{CFLAGS}}
fPIC_flag = {{fPIC}}
obj_files = {{obj_files}}
%.obj : %.{{cpp}}
{{cfile_cmd}}
%.obj : %.bin
{{bfile_cmd}}
.PHONY: all clean clean_constants
all: {{target}}
{{target}}: $(obj_files)
$(CC) -shared $(fPIC_flag) $(CFLAGS) -o $@ $(obj_files)
clean:
rm -f *.obj {{target}} test.so
clean_constants:
rm -f constants.bin
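Makefile2 is a Jinja-style template rather than a usable makefile; the {{...}} placeholders are meant to be filled in by the Python generation layer. A minimal sketch of rendering it with jinja2 is shown below; the substitution values (compiler, flags, object list, commands, target name) are illustrative assumptions and do not appear anywhere in this commit.

# Hypothetical rendering of Makefile2; every render() value below is a
# placeholder chosen for illustration only.
import jinja2

with open("python/AIT implementation/generation/Makefile2") as f:
    makefile_template = jinja2.Template(f.read())

rendered = makefile_template.render(
    cc="hipcc",
    CFLAGS="-std=c++17",
    fPIC="-fPIC",
    obj_files="xx.obj",
    cpp="cpp",
    cfile_cmd="$(CC) $(CFLAGS) -c $< -o $@",
    bfile_cmd="ld -r -b binary $< -o $@",
    target="test.so",
)
print(rendered)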
python/AIT implementation/generation/demo.py (new file, mode 100644)

import jinja2

SHAPE_EVAL_TEMPLATE = jinja2.Template(
    """
int M = *in_{{ range(rank - 1)|join(' * *in_') }};
int N = *in_{{rank - 1}};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<K1>{};
"""
)

output = SHAPE_EVAL_TEMPLATE.render(rank=2)
print(output)
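For reference, the joined-range expression in SHAPE_EVAL_TEMPLATE folds all leading input dimensions into the product M and keeps the last dimension as N. A small sketch, reusing the template defined above with rank=3 (an arbitrary illustrative value):

# Rendering the same template at rank=3 produces:
#   int M = *in_0 * *in_1;
#   int N = *in_2;
# because range(rank - 1) -> [0, 1] and join(' * *in_') glues the indices together.
print(SHAPE_EVAL_TEMPLATE.render(rank=3))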
python/AIT implementation/generation/gemm_dev_op.py (new file, mode 100644)

import enum
import os.path
import shutil
import functools
import operator
import collections
import re


def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text


class EmitGemmInstance:
    def __init__(self):
        self.gemm_devop_template = """
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ${type_a},
typename ${type_b},
typename ${type_c},
typename ${type_acc},
typename ${layout_a},
typename ${layout_b},
typename ${layout_c},
typename ${elementwise_op_a},
typename ${elementwise_op_b},
typename ${elementwise_op_c},
${Gemm_spec},
${block_size},
${mperblock},
${nperblock},
${k0perblock},
${k1},
${m1perthread},
${n1perthread},
${kperthread},
typename ${m1n1_thcluster_m1xs},
typename ${m1n1_thcluster_n1xs},
typename ${ABT_thread_slice_lengths_K0_M0_M1_K1},
typename ${ABT_thread_cluster_lengths_K0_M0_M1_K1},
typename ${ABT_thread_cluster_arrange_order},
typename ${ABT_src_access_order},
typename ${ABT_src_vec_tensor_lengths_K0_M0_M1_K1},
typename ${ABT_src_vec_tensor_cont_dim_order},
typename ${ABT_dst_vec_tensor_lengths_K0_M0_M1_K1},
typename ${BBT_thread_slice_lengths_K0_N0_N1_K1},
typename ${BBT_thread_cluster_lengths_K0_N0_N1_K1},
typename ${BBT_thread_cluster_arrange_order},
typename ${BBT_src_access_order},
typename ${BBT_src_vec_tensor_lengths_K0_N0_N1_K1},
typename ${BBT_src_vec_tensor_cont_dim_order},
typename ${BBT_dst_vec_tensor_lengths_K0_N0_N1_K1},
typename ${CTT_src_dst_access_order},
${CTT_src_dst_vec_dim},
${CTT_dst_scalar_per_vector}>
struct DeviceGemmDl : public DeviceGemm<${layout_a},
${layout_b},
${layout_c},
${type_a},
${type_b},
${type_c},
${elementwise_op_a},
${elementwise_op_b},
${elementwise_op_c}>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto K1Number = Number<${k1}>{};
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
{
assert(K % ${k1} == 0);
const index_t K0 = K / ${k1};
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ${layout_a}>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ${layout_a}>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(${Gemm_spec} == GemmSpecialization::MNPadding)
{
const auto PadM = (${mperblock} - M % ${mperblock}) % ${mperblock};
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
{
assert(K % ${k1} == 0);
const index_t K0 = K / ${k1};
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ${layout_b}>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ${layout_B}>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(${Gemm_spec} == GemmSpecialization::MNPadding)
{
const auto PadN = (${nperblock} - N % ${nperblock}) % ${nperblock};
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ${layout_c}>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ${layout_c}>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(${Gemm_spec} == GemmSpecialization::MNPadding)
{
const auto PadM = (${mperblock} - M % ${mperblock}) % ${mperblock};
const auto PadN = (${nperblock} - N % ${nperblock}) % ${nperblock};
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
${type_a},
${type_acc},
${type_c},
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
${mperblock},
${nperblock},
${k0perblock},
${k1},
${m1perthread},
${n1perthread},
${kperthread},
${m1n1_thcluster_m1xs},
${m1n1_thcluster_n1xs},
${ABT_thread_slice_lengths_K0_M0_M1_K1},
${ABT_thread_cluster_lengths_K0_M0_M1_K1},
${ABT_thread_cluster_arrange_order},
${ABT_src_access_order},
${ABT_src_vec_tensor_lengths_K0_M0_M1_K1},
${ABT_src_vec_tensor_cont_dim_order},
${ABT_dst_vec_tensor_lengths_K0_M0_M1_K1},
${BBT_thread_slice_lengths_K0_N0_N1_K1},
${BBT_thread_cluster_lengths_K0_N0_N1_K1},
${BBT_thread_cluster_arrange_order},
${BBT_src_access_order},
${BBT_src_vec_tensor_lengths_K0_N0_N1_K1},
${BBT_src_vec_tensor_cont_dim_order},
${BBT_dst_vec_tensor_lengths_K0_N0_N1_K1},
${CTT_src_dst_access_order},
${CTT_src_dst_vec_dim},
${CTT_dst_scalar_per_vector}>;
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(const ${type_a}* p_a_grid,
const ${type_b}* p_b_grid,
${type_c}* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t M01,
index_t N01,
${elementwise_op_a} a_element_op,
${elementwise_op_b} b_element_op,
${elementwise_op_c} c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
{
a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
}
}
// private:
const ${type_a}* p_a_grid_;
const ${type_b}* p_b_grid_;
${type_c}* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
DefaultBlock2CTileMap block_2_ctile_map_;
// TODO: unused, but may be useful in future.
index_t M01_;
index_t N01_;
// TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
${elementwise_op_a} a_element_op_;
${elementwise_op_b} b_element_op_;
${elementwise_op_c} c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceGemmDl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
}
const index_t grid_size = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
${type_a},
${type_c},
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(${block_size}),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
${type_a},
${type_c},
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(${block_size}),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
${type_a},
${type_c},
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(${block_size}),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
${type_a},
${type_c},
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(${block_size}),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
}
else
{
return false;
}
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ${type_a}* p_a,
const ${type_b}* p_b,
${type_c}* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
${elementwise_op_a} a_element_op,
${elementwise_op_b} b_element_op,
${elementwise_op_c} c_element_op)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
${elementwise_op_a} a_element_op,
${elementwise_op_b} b_element_op,
${elementwise_op_c} c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ${type_a}*>(p_a),
static_cast<const ${type_b}*>(p_b),
static_cast<${type_c}*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDl"
<< "<"
<< ${block_size} << ", "
<< ${mperblock} << ", "
<< ${nperblock} << ", "
<< ${k0perblock} << ", "
<< ${k1} << ", "
<< ${m1perthread} << ", "
<< ${n1perthread} << ", "
<< ${kperthread}
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
"""
    def emit(self):
        values = {
            'type_a': 'ck::half_t',
            'type_b': 'ck::half_t',
            'type_c': 'ck::half_t',
            'type_acc': 'float',
            'layout_a': 'ck::tensor_layout::gemm::ColMajor',
            'layout_b': 'ck::tensor_layout::gemm::RowMajor',
            'layout_c': 'ck::tensor_layout::gemm::RowMajor',
            'elementwise_op_a': 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_b': 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_c': 'ck::tensor_operation::element_wise::PassThrough',
            'Gemm_spec': 'ck::tensor_operation::device::GemmSpecialization::Default',
            'block_size': '256',
            'mperblock': '128',
            'nperblock': '128',
            'k0perblock': '16',
            'k1': '2',
            'm1perthread': '4',
            'n1perthread': '4',
            'kperthread': '1',
            'm1n1_thcluster_m1xs': 'S<8, 2>',
            'm1n1_thcluster_n1xs': 'S<8, 2>',
            'ABT_thread_slice_lengths_K0_M0_M1_K1': 'S<2, 1, 4, 2>',
            'ABT_thread_cluster_lengths_K0_M0_M1_K1': 'S<8, 1, 32, 1>',
            'ABT_thread_cluster_arrange_order': 'S<0, 3, 1, 2>',
            'ABT_src_access_order': 'S<0, 3, 1, 2>',
            'ABT_src_vec_tensor_lengths_K0_M0_M1_K1': 'S<1, 1, 4, 1>',
            'ABT_src_vec_tensor_cont_dim_order': 'S<0, 3, 1, 2>',
            'ABT_dst_vec_tensor_lengths_K0_M0_M1_K1': 'S<1, 1, 4, 2>',
            'BBT_thread_slice_lengths_K0_N0_N1_K1': 'S<2, 1, 4, 2>',
            'BBT_thread_cluster_lengths_K0_N0_N1_K1': 'S<8, 1, 32, 1>',
            'BBT_thread_cluster_arrange_order': 'S<0, 3, 1, 2>',
            'BBT_src_access_order': 'S<0, 3, 1, 2>',
            'BBT_src_vec_tensor_lengths_K0_N0_N1_K1': 'S<1, 1, 4, 1>',
            'BBT_src_vec_tensor_cont_dim_order': 'S<0, 3, 1, 2>',
            'BBT_dst_vec_tensor_lengths_K0_N0_N1_K1': 'S<1, 1, 4, 2>',
            'CTT_src_dst_access_order': 'S<0, 1, 2, 3, 4, 5>',
            'CTT_src_dst_vec_dim': '5',
            'CTT_dst_scalar_per_vector': '4'
        }
        template = self.gemm_devop_template
        cf = open("xx.cpp", 'w')
        print(SubstituteTemplate(template, values))
        cf.write(SubstituteTemplate(template, values))
        cf.close()


a = EmitGemmInstance()
a.emit()
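To make the `${...}` substitution loop above concrete, here is a tiny, self-contained usage sketch. The snippet text and key names are invented for the example and are not part of gemm_dev_op.py; the point is that the outer while-loop keeps substituting until nothing changes, so values that themselves contain placeholders get resolved too.

import re

def SubstituteTemplate(template, values):
    # Same loop as in gemm_dev_op.py: substitute until no ${key} pattern changes the text.
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text

# Illustrative only: a two-level substitution where one value contains another placeholder.
snippet = "using Layout = ${layout};"
print(SubstituteTemplate(snippet, {"layout": "${row}", "row": "ck::tensor_layout::gemm::RowMajor"}))
# -> using Layout = ck::tensor_layout::gemm::RowMajor;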
python/AIT implementation/generation/gemm_kernel.py (new file, mode 100644)

import enum
import os.path
import shutil
import functools
import operator
import collections
import re


def SubstituteTemplate(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
                text = newtext
    return text


class EmitGemmInstance:
    def __init__(self):
        self.gemm_kernel_template = """
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename Block2CTileMap,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_v1r3(const ${type_ab}* __restrict__ ${p_a_grid},
const ${type_ab}* __restrict__ ${p_b_grid},
${type_c}* __restrict__ ${p_c_grid},
const ${A_GridDesc_K0_M_K1} ${a_grid_desc_k0_m0_m1_k1},
const ${BGridDesc_K0_N_K1} ${b_grid_desc_k0_n0_n1_k1},
const ${CGridDesc_M0_M10_M11_N0_N10_N11} ${c_grid_desc_m0_m10_m11_n0_n10_n11},
const Block2CTileMap ${block_2_ctile_map})
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(${type_ab});
__shared__ ${type_ab} p_shared_block[shared_block_size];
GridwiseGemm::Run(${p_a_grid},
${p_b_grid},
${p_c_grid},
p_shared_block,
${a_grid_desc_k0_m0_m1_k1},
${b_grid_desc_k0_n0_n1_k1},
${c_grid_desc_m0_m10_m11_n0_n10_n11},
${block_2_ctile_map},
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
${type_ab},
${type_acc},
${type_c},
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
${A_GridDesc_K0_M_K1},
${BGridDesc_K0_N_K1},
${CGridDesc_M_N},
${mperblock},
${nperblock},
${k0perblock},
${k1value},
${M1PerThreadM111},
${N1PerThreadN111},
${KPerThread},
${M11N11ThreadClusterM110Xs},
${M11N11ThreadClusterN110Xs},
${ABlockTransferThreadSliceLengths_K0_M0_M1_K1},
${ABlockTransferThreadClusterLengths_K0_M0_M1_K1},
${ABlockTransferThreadClusterArrangeOrder},
${ABlockTransferSrcAccessOrder},
${ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1},
${ABlockTransferSrcVectorTensorContiguousDimOrder},
${ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1},
${BBlockTransferThreadSliceLengths_K0_N0_N1_K1},
${BBlockTransferThreadClusterLengths_K0_N0_N1_K1},
${BBlockTransferThreadClusterArrangeOrder},
${BBlockTransferSrcAccessOrder},
${BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1},
${BBlockTransferSrcVectorTensorContiguousDimOrder},
${BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1},
${CThreadTransferSrcDstAccessOrder},
${CThreadTransferSrcDstVectorDim},
${CThreadTransferDstScalarPerVector}>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
"""
    def emit(self):
        values = {
            'function_name': "gemm",
            'type_a': 'ck::half_t',
            'type_b': 'ck::half_t',
            'type_c': 'ck::half_t',
            'type_acc': 'float',
            'layout_a': 'ck::tensor_layout::gemm::RowMajor',
            'layout_b': 'ck::tensor_layout::gemm::RowMajor',
            'layout_c': 'ck::tensor_layout::gemm::RowMajor',
            'elementwise_op_a': 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_b': 'ck::tensor_operation::element_wise::PassThrough',
            'elementwise_op_c': 'ck::tensor_operation::element_wise::PassThrough',
            'Gemm_spec': 'ck::tensor_operation::device::GemmSpecialization::MNKPadding',
            'block_size': '256',
            'mperblock': '64',
            'nperblock': '128',
            'kperblock': '32',
            'k1': '8',
            'mperxdl': '32',
            'nperxdl': '32',
            'mxdlperwave': '1',
            'nxdlperwave': '2',
            'threadclusterlength_a': 'ck::Sequence<4,64,1>',
            'threadclusterarrange_a': 'ck::Sequence<1,0,2>',
            'srcaccessorder_a': 'ck::Sequence<1,0,2>',
            'srcvectordim_a': '2',
            'srcscalarpervec_a': '8',
            'dstscalarpervec_a': '8',
            'add_extra_dim_a': '1',
            'threadclusterlength_b': 'ck::Sequence<8,32,1>',
            'threadclusterarrange_b': 'ck::Sequence<0,2,1>',
            'srcaccessorder_b': 'ck::Sequence<0,2,1>',
            'srcvectordim_b': '1',
            'srcscalarpervec_b': '4',
            'dstscalarpervec_b': '2',
            'add_extra_dim_b': '0',
            'dstscalarpervec_c': '8'
        }
        template = self.gemm_kernel_template
        print(SubstituteTemplate(template, values))
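gemm_kernel.py only prints the partially substituted kernel template. A plausible next step, which this commit does not implement, is to write the result to a file the same way gemm_dev_op.py writes xx.cpp. A sketch under that assumption:

# Sketch only: mirrors the xx.cpp write-out in gemm_dev_op.py. The output name
# "kernel_gen.cpp" and the substitution keys passed here are illustrative; only
# ${type_ab} and ${type_c} actually appear as placeholders in the template above.
emitter = EmitGemmInstance()  # the class defined in gemm_kernel.py above
rendered = SubstituteTemplate(emitter.gemm_kernel_template,
                              {"type_ab": "ck::half_t", "type_c": "ck::half_t"})
with open("kernel_gen.cpp", "w") as cf:
    cf.write(rendered)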
python/AIT implementation/generation/norm_ex.py (new file, mode 100644)

import os
import re
from hashlib import sha1
from typing import Any, Dict, OrderedDict

import jinja2

#from ...target import Target

#templating
FUNC_CALL_PARAM_TEMPLATE = jinja2.Template("(void *)({{name}})")
INSTANCE_TEMPLATE = jinja2.Template(
"""
using {{name}} = {{ config_name }};
"""
)
ARGS_PARSE_TEMPLATE = jinja2.Template(
"""
{% for idx in range(rank) %}
const int64_t in_{{idx}} = std::stoi(argv[{{ idx + 1 }}]);
{% endfor %}
"""
)
STRUCTS_DEF_TEMPLATE = jinja2.Template(
"""
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
ptrs.reserve(512);
}
~ProfilerMemoryPool() {
for(int i = 0; i < ptrs.size(); i++){
hipFree(ptrs[i]);
}
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
size_t length = size * sizeof(DType);
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
"""
)
FUNC_TEMPLATE = jinja2.Template(
"""
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/reduction_operator.hpp"
{{extra_headers}}
{{extra_code}}
{{instances_decl}}
{{func_signature}}
{
{{shape_eval}}
{{exec_paths}}
}
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{input}},
{{indent}} {{output}},
{% for name in input_dim_names %}
{{indent}} const_cast<int64_t *>(&{{name}}),
{% endfor %}
{{indent}} stream
{{indent}});
"""
)
PROFILER_TEMPLATE = jinja2.Template(
"""
size_t GLOBAL_WORKSPACE_SIZE = 0;
{{op_func}}
{{structs_def}}
int main(int argc, char** argv) {
{{args_parse}}
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
{{tensor_decl}}
// warmup
for(int i = 0; i < 3; ++i) {
{{func_call}}
}
// run
KernelTimerImpl timer;
timer.Start();
for(int i = 0; i < 5; ++i) {
{{func_call}}
}
timer.End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
}
"""
)
# rendering (messy, need to modularize and organize)
# def gen_profiler(
# shape_eval_template: jinja2.Template,
# exec_template: jinja2.Template,
# tensor_decl_template: jinja2.Template,
# extra_header_template: jinja2.Template,
# get_func_signature: Any,
# extra_code: str = "",
# func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE,
# indent: str = " ",
# ) -> str:
# shape_eval_template: jinja2.Template
# exec_template: jinja2.Template
# tensor_decl_template: jinja2.Template
#extra_header_template: jinja2.Template
get_func_signature: Any
extra_code: str = ""
func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE
indent: str = " "
# shape_eval = shape_eval_template.render(rank=2) #if shape_eval_template else ""
# exe_path = exec_template.render(instance="DeviceInstance",dtype="void",reduce_dims=1,rank=2,eps=eps,)
instances = INSTANCE_TEMPLATE.render(
    name="DeviceInstance",
    config_name="ck::tensor_operation::device::DeviceLayernormImpl",
)
op_func = FUNC_TEMPLATE.render(
    instances_decl=instances,
    #func_signature=get_func_signature(func_attrs),
    #shape_eval=shape_eval,
    #exec_paths=exe_path,
    #extra_headers=extra_header_template.render(),
    extra_code=extra_code,
)
structs_def = STRUCTS_DEF_TEMPLATE.render()
args_parse = ARGS_PARSE_TEMPLATE.render(rank=2)
#tensor_decl = tensor_decl_template.render(rank=2)
input_dim_names = [f"in_{i}" for i in range(2)]
func_call = func_call_template.render(
    func_name="norm",
    input="(void *) memory_pool->RequestHalfTensorByIdx(0)",
    gamma="(void *) memory_pool->RequestHalfTensorByIdx(2)",
    beta="(void *) memory_pool->RequestHalfTensorByIdx(3)",
    output="(void *) memory_pool->RequestHalfTensorByIdx(1)",
    input_dim_names=input_dim_names,
    indent=indent,
)
code = PROFILER_TEMPLATE.render(
    op_func=op_func,
    structs_def=structs_def,
    args_parse=args_parse,
    #tensor_decl=tensor_decl,
    func_call=func_call,
)
# print(instances)
# print(args_parse)
# print(structs_def)
#print(func_call)
#print(op_func)
print(code)
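The commented-out gen_profiler(...) signature above suggests how this module-level rendering is meant to be packaged. A minimal sketch of that wrapper, reusing the same templates and the same render arguments already present in the file; the function name, parameter defaults, and return type are assumptions, not something this commit defines:

def gen_profiler(extra_code: str = "",
                 func_call_template: jinja2.Template = FUNC_CALL_TEMPLATE,
                 indent: str = " ") -> str:
    # Same rendering steps as the module-level code above, just wrapped in a function.
    instances = INSTANCE_TEMPLATE.render(
        name="DeviceInstance",
        config_name="ck::tensor_operation::device::DeviceLayernormImpl")
    op_func = FUNC_TEMPLATE.render(instances_decl=instances, extra_code=extra_code)
    structs_def = STRUCTS_DEF_TEMPLATE.render()
    args_parse = ARGS_PARSE_TEMPLATE.render(rank=2)
    input_dim_names = [f"in_{i}" for i in range(2)]
    func_call = func_call_template.render(
        func_name="norm",
        input="(void *) memory_pool->RequestHalfTensorByIdx(0)",
        output="(void *) memory_pool->RequestHalfTensorByIdx(1)",
        input_dim_names=input_dim_names,
        indent=indent)
    return PROFILER_TEMPLATE.render(op_func=op_func,
                                    structs_def=structs_def,
                                    args_parse=args_parse,
                                    func_call=func_call)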
python/AIT implementation/generation/permute_ex.py (new file, mode 100644)

import jinja2

EXTRA_SHAPE_TEMPLATE = jinja2.Template(
"""
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
ck::index_t M0 = M / G1 / G2;
ck::index_t M1 = G1;
ck::index_t M2 = G2;
ck::index_t N0 = G3;
ck::index_t N1 = N / G3;
// GEMM shape
//ck::index_t M = M0 * M1 * M2;
//ck::index_t N = N0 * N1;
//ck::index_t K = 128;
//ck::index_t stride_A = K;
//ck::index_t stride_B = K;
// E = [M0, N0, M1, N1, M2]
/* 0, 3, 1, 4, 2
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
*/
// E = [M2, M0, N0, M1, N1] 2, 0, 3, 1, 4
ck::index_t stride_E_M0 = N0* M1* N1;
ck::index_t stride_E_M1 = N1;
ck::index_t stride_E_M2 = M0* N0* M1* N1;
ck::index_t stride_E_N0 = M1 * N1;
ck::index_t stride_E_N1 = 1;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
"""
)
output = EXTRA_SHAPE_TEMPLATE.render(indent=" ")
print(output)
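The stride formulas in EXTRA_SHAPE_TEMPLATE describe the output layout E = [M2, M0, N0, M1, N1] (permutation 2, 0, 3, 1, 4). A small numerical check of those formulas; the sizes below are arbitrary illustration values, not anything defined in the commit:

# Arbitrary example sizes: M = 24, N = 8, G1 = 2, G2 = 3, G3 = 4.
M, N, G1, G2, G3 = 24, 8, 2, 3, 4
M0, M1, M2 = M // G1 // G2, G1, G2      # 4, 2, 3
N0, N1 = G3, N // G3                    # 4, 2
# Strides for E = [M2, M0, N0, M1, N1]; the innermost N1 is contiguous.
stride_E_M0 = N0 * M1 * N1              # 16
stride_E_M1 = N1                        # 2
stride_E_M2 = M0 * N0 * M1 * N1         # 64
stride_E_N0 = M1 * N1                   # 4
stride_E_N1 = 1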
python/AIT implementation/generation/xx.cpp (new file, mode 100644)
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <typename ck::half_t,
          typename ck::half_t,
          typename ck::half_t,
          typename float,
          typename ck::tensor_layout::gemm::ColMajor,
          typename ck::tensor_layout::gemm::RowMajor,
          typename ck::tensor_layout::gemm::RowMajor,
          typename ck::tensor_operation::element_wise::PassThrough,
          typename ck::tensor_operation::element_wise::PassThrough,
          typename ck::tensor_operation::element_wise::PassThrough,
          ck::tensor_operation::device::GemmSpecialization::Default,
          256, 128, 128, 16, 2, 4, 4, 1,
          typename S<8, 2>, typename S<8, 2>,
          typename S<2, 1, 4, 2>, typename S<8, 1, 32, 1>, typename S<0, 3, 1, 2>,
          typename S<0, 3, 1, 2>, typename S<1, 1, 4, 1>, typename S<0, 3, 1, 2>,
          typename S<1, 1, 4, 2>,
          typename S<2, 1, 4, 2>, typename S<8, 1, 32, 1>, typename S<0, 3, 1, 2>,
          typename S<0, 3, 1, 2>, typename S<1, 1, 4, 1>, typename S<0, 3, 1, 2>,
          typename S<1, 1, 4, 2>,
          typename S<0, 1, 2, 3, 4, 5>,
          5,
          4>
struct DeviceGemmDl : public DeviceGemm<ck::tensor_layout::gemm::ColMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::tensor_layout::gemm::RowMajor,
                                        ck::half_t, ck::half_t, ck::half_t,
                                        ck::tensor_operation::element_wise::PassThrough,
                                        ck::tensor_operation::element_wise::PassThrough,
                                        ck::tensor_operation::element_wise::PassThrough>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto K1Number = Number<2>{};

    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
    {
        assert(K % 2 == 0);
        const index_t K0 = K / 2;
        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColMajor>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColMajor>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();
        if constexpr(ck::tensor_operation::device::GemmSpecialization::Default == GemmSpecialization::MNPadding)
        {
            const auto PadM = (128 - M % 128) % 128;
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(M, PadM)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
    {
        assert(K % 2 == 0);
        const index_t K0 = K / 2;
        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ${layout_B}>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();
        if constexpr(ck::tensor_operation::device::GemmSpecialization::Default == GemmSpecialization::MNPadding)
        {
            const auto PadN = (128 - N % 128) % 128;
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();
        if constexpr(ck::tensor_operation::device::GemmSpecialization::Default == GemmSpecialization::MNPadding)
        {
            const auto PadM = (128 - M % 128) % 128;
            const auto PadN = (128 - N % 128) % 128;
            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }

    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm =
        GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                     ck::half_t, float, ck::half_t,
                                     InMemoryDataOperationEnum::Set,
                                     AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N,
                                     128, 128, 16, 2, 4, 4, 1,
                                     S<8, 2>, S<8, 2>,
                                     S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>,
                                     S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>,
                                     S<1, 1, 4, 2>,
                                     S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>,
                                     S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>,
                                     S<1, 1, 4, 2>,
                                     S<0, 1, 2, 3, 4, 5>,
                                     5,
                                     4>;

    using AGridDesc_K0_M0_M1_K1 =
        decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 =
        decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
    using DefaultBlock2CTileMap =
        decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ck::half_t* p_a_grid,
                 const ck::half_t* p_b_grid,
                 ck::half_t* p_c_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 index_t M01,
                 index_t N01,
                 ck::tensor_operation::element_wise::PassThrough a_element_op,
                 ck::tensor_operation::element_wise::PassThrough b_element_op,
                 ck::tensor_operation::element_wise::PassThrough c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_k0_m0_m1_k1_{},
              b_grid_desc_k0_n0_n1_k1_{},
              c_grid_desc_m0_m10_m11_n0_n10_n11_{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
            a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
            b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
            c_grid_desc_m_n_     = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC);

            if(GridwiseGemm::CheckValidity(
                   a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
            {
                a_grid_desc_k0_m0_m1_k1_ =
                    GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
                b_grid_desc_k0_n0_n1_k1_ =
                    GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
                c_grid_desc_m0_m10_m11_n0_n10_n11_ =
                    GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
                block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
            }
        }

        // private:
        const ck::half_t* p_a_grid_;
        const ck::half_t* p_b_grid_;
        ck::half_t* p_c_grid_;
        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
        BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
        CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
        DefaultBlock2CTileMap block_2_ctile_map_;
        // TODO: unused, but may be useful in future.
        index_t M01_;
        index_t N01_;
        // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
        ck::tensor_operation::element_wise::PassThrough a_element_op_;
        ck::tensor_operation::element_wise::PassThrough b_element_op_;
        ck::tensor_operation::element_wise::PassThrough c_element_op_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceGemmDl::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            {
                std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
                std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
            }

            if(!GridwiseGemm::CheckValidity(
                   arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
            }

            const index_t grid_size = GridwiseGemm::CalculateGridSize(
                arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));

            const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);

            const bool has_main_k_block_loop        = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
            const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

            float ave_time = 0;

            if(has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm, ck::half_t, ck::half_t,
                    remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                    remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                    remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                    remove_reference_t<DefaultBlock2CTileMap>, true, true>;
                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size), dim3(256), 0,
                    arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                    arg.a_grid_desc_k0_m0_m1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
                    arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.block_2_ctile_map_);
            }
            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
            {
                const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm, ck::half_t, ck::half_t,
                    remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                    remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                    remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                    remove_reference_t<DefaultBlock2CTileMap>, true, false>;
                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size), dim3(256), 0,
                    arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                    arg.a_grid_desc_k0_m0_m1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
                    arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.block_2_ctile_map_);
            }
            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm, ck::half_t, ck::half_t,
                    remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                    remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                    remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                    remove_reference_t<DefaultBlock2CTileMap>, false, true>;
                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size), dim3(256), 0,
                    arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                    arg.a_grid_desc_k0_m0_m1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
                    arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.block_2_ctile_map_);
            }
            else
            {
                const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm, ck::half_t, ck::half_t,
                    remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                    remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                    remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                    remove_reference_t<DefaultBlock2CTileMap>, false, false>;
                ave_time = launch_and_time_kernel(stream_config, kernel, dim3(grid_size), dim3(256), 0,
                    arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                    arg.a_grid_desc_k0_m0_m1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
                    arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.block_2_ctile_map_);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")
        {
            return GridwiseGemm::CheckValidity(
                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
        }
        else
        {
            return false;
        }
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ck::half_t* p_a,
                             const ck::half_t* p_b,
                             ck::half_t* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             ck::tensor_operation::element_wise::PassThrough a_element_op,
                             ck::tensor_operation::element_wise::PassThrough b_element_op,
                             ck::tensor_operation::element_wise::PassThrough c_element_op)
    {
        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1,
                        a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      ck::tensor_operation::element_wise::PassThrough a_element_op,
                                                      ck::tensor_operation::element_wise::PassThrough b_element_op,
                                                      ck::tensor_operation::element_wise::PassThrough c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ck::half_t*>(p_a),
                                          static_cast<const ck::half_t*>(p_b),
                                          static_cast<ck::half_t*>(p_c),
                                          M, N, K, StrideA, StrideB, StrideC, 1, 1,
                                          a_element_op, b_element_op, c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // clang-format off
        str << "DeviceGemmDl"
            << "<"
            << 256 << ", "
            << 128 << ", "
            << 128 << ", "
            << 16 << ", "
            << 2 << ", "
            << 4 << ", "
            << 4 << ", "
            << 1
            << ">";
        // clang-format on
        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
python/AIT implementation/sample files/ck/device_gemm_dl_ck.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t K1,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          typename M1N1ThreadClusterM1Xs,
          typename M1N1ThreadClusterN1Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          enable_if_t<is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
                          is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
                          is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
                      bool> = false>
struct DeviceGemmDl : public DeviceGemm<ALayout,
                                        BLayout,
                                        CLayout,
                                        ADataType,
                                        BDataType,
                                        CDataType,
                                        AElementwiseOperation,
                                        BElementwiseOperation,
                                        CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    static constexpr auto K1Number = Number<K1>{};

    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(M, PadM)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                c_grid_desc_m_n,
                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
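    // Illustrative note (added commentary, not part of the upstream header):
    // MakeAGridDescriptor_K0_M_K1 splits K into [K0, K1] with K = K0 * K1 and, under
    // GemmSpecialization::MNPadding, right-pads M up to a multiple of MPerBlock.
    // For assumed values M = 1000, K = 256, K1 = 2, MPerBlock = 128:
    //   K0   = K / K1 = 128
    //   PadM = (MPerBlock - M % MPerBlock) % MPerBlock = (128 - 104) % 128 = 24
    // so the A view becomes [K0, M + PadM, K1] = [128, 1024, 2]. The B and C
    // descriptor builders above apply the same scheme to N.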
    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm =
        GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                     ADataType, AccDataType, CDataType,
                                     InMemoryDataOperationEnum::Set,
                                     AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N,
                                     MPerBlock, NPerBlock, K0PerBlock, K1,
                                     M1PerThread, N1PerThread, KPerThread,
                                     M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs,
                                     ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                     ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                     ABlockTransferThreadClusterArrangeOrder,
                                     ABlockTransferSrcAccessOrder,
                                     ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                     ABlockTransferSrcVectorTensorContiguousDimOrder,
                                     ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                     BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                     BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                     BBlockTransferThreadClusterArrangeOrder,
                                     BBlockTransferSrcAccessOrder,
                                     BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                     BBlockTransferSrcVectorTensorContiguousDimOrder,
                                     BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                     CThreadTransferSrcDstAccessOrder,
                                     CThreadTransferSrcDstVectorDim,
                                     CThreadTransferDstScalarPerVector>;

    using AGridDesc_K0_M0_M1_K1 =
        decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 =
        decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
    using DefaultBlock2CTileMap =
        decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 CDataType* p_c_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 index_t M01,
                 index_t N01,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_k0_m0_m1_k1_{},
              b_grid_desc_k0_n0_n1_k1_{},
              c_grid_desc_m0_m10_m11_n0_n10_n11_{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
            a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
            b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
            c_grid_desc_m_n_     = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC);

            if(GridwiseGemm::CheckValidity(
                   a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
            {
                a_grid_desc_k0_m0_m1_k1_ =
                    GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
                b_grid_desc_k0_n0_n1_k1_ =
                    GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
                c_grid_desc_m0_m10_m11_n0_n10_n11_ =
                    GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);

                block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
            }
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;

        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;

        AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
        BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
        CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;

        DefaultBlock2CTileMap block_2_ctile_map_;

        // TODO: unused, but may be useful in future.
        index_t M01_;
        index_t N01_;

        // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceGemmDl::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            {
                std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
                          << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
            }

            if(!GridwiseGemm::CheckValidity(
                   arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
            }

            const index_t grid_size = GridwiseGemm::CalculateGridSize(
                arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));

            const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
            const bool has_double_tail_k_block_loop =
                GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

            float ave_time = 0;

            if(has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                const auto kernel =
                    kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType,
                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                                        remove_reference_t<DefaultBlock2CTileMap>,
                                        true, true>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                    arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                    arg.a_grid_desc_k0_m0_m1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
                    arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.block_2_ctile_map_);
            }
            else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
            {
                const auto kernel =
                    kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType,
                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                                        remove_reference_t<DefaultBlock2CTileMap>,
                                        true, false>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                    arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                    arg.a_grid_desc_k0_m0_m1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
                    arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.block_2_ctile_map_);
            }
            else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
            {
                const auto kernel =
                    kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType,
                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                                        remove_reference_t<DefaultBlock2CTileMap>,
                                        false, true>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                    arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                    arg.a_grid_desc_k0_m0_m1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
                    arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.block_2_ctile_map_);
            }
            else
            {
                const auto kernel =
                    kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType,
                                        remove_reference_t<AGridDesc_K0_M0_M1_K1>,
                                        remove_reference_t<BGridDesc_K0_N0_N1_K1>,
                                        remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                                        remove_reference_t<DefaultBlock2CTileMap>,
                                        false, false>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                    arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                    arg.a_grid_desc_k0_m0_m1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
                    arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.block_2_ctile_map_);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")
        {
            return GridwiseGemm::CheckValidity(
                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
        }
        else
        {
            return false;
        }
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1,
                        a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M, N, K, StrideA, StrideB, StrideC, 1, 1,
                                          a_element_op, b_element_op, c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGemmDl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << M1PerThread << ", "
            << N1PerThread << ", "
            << KPerThread
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
\ No newline at end of file
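For reference, a hedged sketch of how the templated operator above would typically be specialized in a CK-style instance file. The tile sizes mirror the 256/128/128/16/2/4/4/1 configuration printed by GetTypeString in the generated op earlier in this commit; the Sequence<> block-transfer descriptors are illustrative placeholders, not taken from this commit, and any real instance must still satisfy GridwiseGemm::CheckValidity / IsSupportedArgument.

    // Illustrative instantiation sketch (placeholder tuning values).
    using F16 = ck::half_t;
    using F32 = float;
    using Row = ck::tensor_layout::gemm::RowMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    static constexpr auto GemmDefault =
        ck::tensor_operation::device::GemmSpecialization::Default;

    using DeviceGemmDlFp16RRR = ck::tensor_operation::device::DeviceGemmDl<
        F16, F16, F16, F32,                    // ADataType, BDataType, CDataType, AccDataType
        Row, Row, Row,                         // ALayout, BLayout, CLayout
        PassThrough, PassThrough, PassThrough, // elementwise ops (required by the enable_if above)
        GemmDefault,
        256, 128, 128, 16, 2,                  // BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1
        4, 4, 1,                               // M1PerThread, N1PerThread, KPerThread
        S<8, 2>, S<8, 2>,                      // M1N1ThreadClusterM1Xs / N1Xs
        S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>,  // A block transfer (placeholders)
        S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>,
        S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>,  // B block transfer (placeholders)
        S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>,
        S<0, 1, 2, 3, 4, 5>, 5, 4>;            // C thread transfer: access order, vector dim, scalar/vector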
python/AIT implementation/sample files/ck/gridwise_gemm_dl_v1r3_ck.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_K0_M0_M1_K1,
          typename BGridDesc_K0_N0_N1_K1,
          typename CGridDesc_M0_M10_M11_N0_N10_N11,
          typename Block2CTileMap,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
                            const FloatAB* __restrict__ p_b_grid,
                            FloatC* __restrict__ p_c_grid,
                            const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
                            const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
                            const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
                            const Block2CTileMap block_2_ctile_map)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_grid_desc_k0_m0_m1_k1,
                      b_grid_desc_k0_n0_n1_k1,
                      c_grid_desc_m0_m10_m11_n0_n10_n11,
                      block_2_ctile_map,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t K1Value,
          index_t M1PerThreadM111,
          index_t N1PerThreadN111,
          index_t KPerThread,
          typename M11N11ThreadClusterM110Xs,
          typename M11N11ThreadClusterN110Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmDl_km_kn_mn_v1r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size =
            math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size =
            math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }

    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
                K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
                K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
                K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
    {
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);

        const auto M1 = Number<MPerBlock>{};
        const auto M0 = M / M1;

        const auto a_grid_desc_k0_m0_m1_k1 = transform_tensor_descriptor(
            a_grid_desc_k0_m_k1,
            make_tuple(make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(M0, M1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_grid_desc_k0_m0_m1_k1;
    }

    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
    {
        const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);

        const auto N1 = Number<NPerBlock>{};
        const auto N0 = N / N1;

        const auto b_grid_desc_k0_n0_n1_k1 = transform_tensor_descriptor(
            b_grid_desc_k0_n_k1,
            make_tuple(make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(N0, N1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_grid_desc_k0_n0_n1_k1;
    }

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M11 =
            Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
                   M1PerThreadM111>{};
        constexpr auto N11 =
            Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
                   N1PerThreadN111>{};

        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                       make_unmerge_transform(make_tuple(N0, N10, N11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_grid_desc_m0_m10_m11_n0_n10_n11;
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
            c_grid_desc_m_n);
    }

    using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
    using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
    using CGridDesc_M0_M10_M11_N0_N10_N11 =
        decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
    using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatAB* __restrict__ p_shared_block,
                               const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
                               const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
                               const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
                               const Block2CTileMap& block_2_ctile_map,
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto c_m0_n0_block_cluster_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this forces the index data into SGPRs
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

        if(!block_2_ctile_map.ValidCTileIndex(
               make_tuple(im0, in0),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }

        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, for blockwise GEMM
        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                              a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                          b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                              b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                          "wrong!");

        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum::Set,
            Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
            ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
            decltype(a_block_desc_k0_m0_m1_k1),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3>,
            ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
            ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
            ABlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3>,                             // DstVectorTensorContiguousDimOrder
            false,
            true>(a_grid_desc_k0_m0_m1_k1,
                  make_multi_index(0, im0, 0, 0),
                  a_block_desc_k0_m0_m1_k1,
                  make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum::Set,
            Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
            BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
            decltype(b_block_desc_k0_n0_n1_k1),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3>,
            BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
            BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
            BBlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3>,                             // DstVectorTensorContiguousDimOrder
            false,
            true>(b_grid_desc_k0_n0_n1_k1,
                  make_multi_index(0, in0, 0, 0),
                  b_block_desc_k0_n0_n1_k1,
                  make_multi_index(0, 0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        const auto blockwise_gemm =
            BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
                BlockSize,
                FloatAB,
                FloatAB,
                FloatAcc,
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),
                M1PerThreadM111,
                N1PerThreadN111,
                KPerThread,
                M11N11ThreadClusterM110Xs,
                M11N11ThreadClusterN110Xs,
                M1PerThreadM111,
                N1PerThreadN111>{};

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

        constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

        // Initialize C
        c_thread_buf.Clear();

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
        }

        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);

            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                    b_block_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                   a_block_even_buf, b_block_even_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                    b_block_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                   a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

            block_sync_lds();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                               a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

            block_sync_lds();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                               a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if there is 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                               a_block_even_buf, b_block_even_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
                make_naive_tensor_descriptor_packed(make_tuple(
                    I1,
                    Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                    Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                    I1,
                    Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                    Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                    get_thread_local_1d_id());

            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
                decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                         1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                      make_multi_index(im0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                       in0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                      ck::tensor_operation::element_wise::PassThrough{}}
                .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_grid_desc_m0_m10_m11_n0_n10_n11,
                     c_grid_buf);
        }
    }
};

} // namespace ck
\ No newline at end of file
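A small host-side sanity check, as a sketch (not from the commit): it evaluates the grid size and the two K-loop flags exactly as defined in GridwiseGemmDl_km_kn_mn_v1r3 above, for one assumed tile configuration (MPerBlock = NPerBlock = 128, K0PerBlock = 16, K1 = 2) and problem size M = N = K = 1024.

    #include <iostream>

    int main()
    {
        constexpr int MPerBlock = 128, NPerBlock = 128, K0PerBlock = 16;
        constexpr int M = 1024, N = 1024, K = 1024, K1 = 2;

        constexpr int K0 = K / K1; // 512, as produced by MakeAGridDescriptor_K0_M_K1

        constexpr int grid_size = (M / MPerBlock) * (N / NPerBlock);                  // 8 * 8 = 64 blocks
        constexpr bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; // 528 / 32 = 16 > 1
        constexpr bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;        // 32 % 2 == 0

        std::cout << grid_size << " " << has_main_k_block_loop << " "
                  << has_double_tail_k_block_loop << "\n"; // prints: 64 1 1
        return 0;
    }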
python/AIT implementation/sample files/elementwise_common.py
0 → 100644
#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Backend-agnostic functions for elementwise codegen.
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import jinja2

# from aitemplate.backend.backend_spec import BackendSpec
# from ...compiler.base import IntImm, IntVar, Operator, Tensor
# from ...compiler.tensor_accessor import TensorAccessor
# from ...utils import shape_utils
# from . import tensor_accessor_codegen

CONSTANT_TEMPLATE = jinja2.Template(
    """
#define FUSED_ELE_THREAD_SIZE 256

const int N_ELEMENTS_PER_THREAD = sizeof({{read_t}}) / sizeof({{data_t}});
const int N_ELEMENTS_PER_READ = sizeof({{read_t}}) / sizeof({{data_t}});
const int N_OPS_PER_THREAD = sizeof({{read_t}}) / sizeof({{op_t}});
    """
)

KERNEL_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const {{read_t}}* input{{idx}}")

KERNEL_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("{{read_t}}* output{{idx}}")

KERNEL_TMP_INPUT_TEMPLATE = jinja2.Template("p_tmp_i{{idx}}[i]")

KERNEL_TMP_OUTPUT_TEMPLATE = jinja2.Template("p_tmp_o{{idx}}[i]")

GET_STRIDED_ADDRESS_TEMPLATE = jinja2.Template(
    """
{% if tensor_accessor.is_contiguous %}
{{data_ptr}} = get_strided_address</*data_t*/ {{data_t}},
                                   /*read_t*/ {{read_t}},
                                   /*is_contiguous*/ true>(
    {{data_ptr}}, {{data_idx}}, {{tensor_accessor.offset}}, 0, 0);
{% else %}
{{data_ptr}} = get_strided_address</*data_t*/ {{data_t}},
                                   /*read_t*/ {{read_t}},
                                   /*is_contiguous*/ false>(
    {{data_ptr}}, {{data_idx}},
    {{tensor_accessor.offset}},
    {{tensor_accessor.original_total_elements_from_stride_dim}},
    {{tensor_accessor.actual_total_elements_from_stride_dim}});
{% endif %}
    """
)

KERNEL_READ_INPUT_TEMPLATE = jinja2.Template(
    """
{{read_t}} *{{input_name}} = const_cast<{{read_t}}*>(input{{input_idx}});
{{get_strided_address}}
{{read_t}} tmp_i{{input_idx}} = *{{input_name}};
const {{op_t}}* p_tmp_i{{input_idx}} = reinterpret_cast<const {{op_t}}*>(&tmp_i{{input_idx}});
    """
)

KERNEL_DEFINE_OUTPUTS_TEMPLATE = jinja2.Template(
    """
{% for idx in indexes %}
{{read_t}} tmp_o{{idx}};
{{op_t}}* p_tmp_o{{idx}} = reinterpret_cast<{{op_t}}*>(&tmp_o{{idx}});
{% endfor %}
    """
)

KERNEL_WRITE_OUTPUT_TEMPLATE = jinja2.Template(
    """
{{get_strided_address}}
*{{output_name}} = tmp_o{{output_idx}};
    """
)

KERNEL_TEMPLATE = jinja2.Template(
    """
__global__ void
{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{index_type}} n_elements) {
  const int bid = blockIdx.x;
  const int tid = threadIdx.x;
  const {{index_type}} idx = bid * FUSED_ELE_THREAD_SIZE + tid;
  const {{index_type}} idx_elem = idx * N_ELEMENTS_PER_THREAD;
  if (idx_elem >= n_elements) {
    return;
  }
  {{read_inputs}}
  {{define_outputs}}
#pragma unroll
  for (int i = 0; i < N_OPS_PER_THREAD; ++i) {
    {{fused_funcs}}
  }
  {{write_outputs}}
}
    """
)

FUNC_DECL_INPUT_PARAM_TEMPLATE = jinja2.Template("const void* input{{idx}}")

FUNC_DECL_OUTPUT_PARAM_TEMPLATE = jinja2.Template("void* output{{idx}}")

KERNEL_CALL_INPUT_PARAM_TEMPLATE = jinja2.Template(
    "reinterpret_cast<const {{read_t}}*>(input{{idx}})"
)

KERNEL_CALL_OUTPUT_PARAM_TEMPLATE = jinja2.Template(
    "reinterpret_cast<{{read_t}}*>(output{{idx}})"
)

FUNC_TEMPLATE = jinja2.Template(
    """
{{head}}

namespace {

{{constant}}

{{custom_libs}}

{{tensor_accessor_lib}}

{{kernel_function}}

}  // namespace

void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims_decl}} {{index_type}} n_elements, {{prefix}}Stream_t stream) {
    if (n_elements == 0) {
      return;
    }
    int block_size = static_cast<int>(std::ceil(static_cast<double>(n_elements) / N_ELEMENTS_PER_THREAD / FUSED_ELE_THREAD_SIZE));
    {{func_name}}<<<block_size, FUSED_ELE_THREAD_SIZE, 0, stream>>>(
        {{kernel_call_output_params}},
        {{kernel_call_input_params}},
        {{dynamic_dims_call}}
        n_elements
    );
}
    """
)

FUNC_DECL_TEMPLATE = jinja2.Template(
    """
void invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{index_type}} n_elements, {{prefix}}Stream_t stream);
    """
)

FUNC_CALL_TEMPLATE = jinja2.Template(
    """
{{indent}}{
{{indent}}{{index_type}} {{func_name}}_n_elements = {{calculate_n}};
{{indent}}invoke_{{func_name}}({{output_params}}, {{input_params}}, {{dynamic_dims}} {{func_name}}_n_elements, {{stream}});
{{indent}}}
    """
)
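# Illustrative note (added commentary, not part of the upstream module): for an assumed
# fp16 kernel with read_t = "uint4" (16 bytes), data_t = "half" and op_t = "half2",
# CONSTANT_TEMPLATE above expands to N_ELEMENTS_PER_THREAD = 8 and N_OPS_PER_THREAD = 4,
# i.e. each thread performs one 16-byte vectorized load per input, applies the fused body
# to 4 half2 lanes (see KERNEL_TEMPLATE / KERNEL_READ_INPUT_TEMPLATE), and writes one
# 16-byte vector back per output.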
@
dataclass
class
ElementwiseMetaData
:
func_name
:
str
op_t
:
str
args
:
List
[
Tensor
]
outputs
:
List
[
Tensor
]
@
dataclass
class
FusedElementwiseMetaData
:
# Input / output Tensors and TensorAccessors.
inputs
:
List
[
Tensor
]
outputs
:
List
[
Tensor
]
input_accessors
:
List
[
TensorAccessor
]
output_accessors
:
List
[
TensorAccessor
]
# Original input / output Tensors before graph transformation.
# Kept here for elementwise -> fused elementwise Tensor mapping.
original_inputs
:
List
[
Tensor
]
original_outputs
:
List
[
Tensor
]
read_t
:
str
op_t
:
str
data_t
:
str
input_broadcast_sizes
:
List
[
List
[
IntVar
]]
dynamic_dims
:
List
[
IntVar
]
sub_funcs
:
List
[
ElementwiseMetaData
]
def
gen_function_single_thread
(
fused_func_metadata
,
input_names
,
output_names
,
type_converter
,
)
->
str
:
"""Per thread elementwise function codegen."""
tensor_to_expr
:
Dict
[
Tensor
,
str
]
=
{}
body
=
""
for
tensor
,
name
in
zip
(
fused_func_metadata
.
original_inputs
,
input_names
):
tensor_to_expr
[
tensor
]
=
name
tmp_output_idx
:
int
=
0
for
func_metadata
in
fused_func_metadata
.
sub_funcs
:
params
:
List
[
str
]
=
[]
func_op_t
=
func_metadata
.
op_t
input_converter
=
None
output_converter
=
None
if
func_op_t
!=
fused_func_metadata
.
op_t
:
input_converter
=
type_converter
.
get
(
fused_func_metadata
.
op_t
).
get
(
func_op_t
)
output_converter
=
type_converter
.
get
(
func_op_t
).
get
(
fused_func_metadata
.
op_t
)
assert
(
input_converter
is
not
None
),
"Unsupported convertion from {} to {}"
.
format
(
fused_func_metadata
.
op_t
,
func_op_t
)
assert
(
output_converter
is
not
None
),
"Unsupported convertion from {} to {}"
.
format
(
func_op_t
,
fused_func_metadata
.
op_t
)
for
arg
in
func_metadata
.
args
:
if
arg
in
tensor_to_expr
:
param
=
tensor_to_expr
[
arg
]
params
.
append
(
"{}({})"
.
format
(
input_converter
,
param
)
if
input_converter
is
not
None
else
param
)
elif
arg
.
is_a_const_num
():
if
func_op_t
[
-
1
]
==
"2"
:
params
.
append
(
"{}({},{})"
.
format
(
func_op_t
,
str
(
arg
.
_attrs
[
"value"
]),
str
(
arg
.
_attrs
[
"value"
]),
)
)
else
:
params
.
append
(
"{}({})"
.
format
(
func_op_t
,
str
(
arg
.
_attrs
[
"value"
])))
else
:
raise
RuntimeError
(
"Cannot generate expression for node {}, ops: {}"
.
format
(
arg
,
func_metadata
)
)
assert
(
len
(
func_metadata
.
outputs
)
==
1
),
"Operator has more than 1 output! Operator: {}"
.
format
(
func_metadata
)
output
=
func_metadata
.
outputs
[
0
]
func_def
=
"{}({})"
.
format
(
func_metadata
.
func_name
,
","
.
join
(
params
))
func_def
=
(
"{}({})"
.
format
(
output_converter
,
func_def
)
if
output_converter
is
not
None
else
func_def
)
if
len
(
output
.
_attrs
[
"dst_ops"
])
>
1
:
name
=
"tmp_"
+
(
str
)(
tmp_output_idx
)
tmp_output_idx
+=
1
body
+=
"{} {} = {};
\n
"
.
format
(
fused_func_metadata
.
op_t
,
name
,
func_def
)
tensor_to_expr
[
output
]
=
name
else
:
tensor_to_expr
[
output
]
=
func_def
for
tensor
,
name
in
zip
(
fused_func_metadata
.
original_outputs
,
output_names
):
if
tensor
not
in
tensor_to_expr
:
raise
RuntimeError
(
"Cannot generate expression for node {}, outputs: {}"
.
format
(
tensor
,
fused_func_metadata
.
original_outputs
)
)
expr
=
tensor_to_expr
[
tensor
]
body
+=
"{} = {};
\n
"
.
format
(
name
,
expr
)
return
body
def
_get_sub_func_metadata
(
ops
:
List
[
Operator
],
data_t
:
str
,
op_t
:
str
,
backend_spec
:
BackendSpec
)
->
Tuple
[
List
[
ElementwiseMetaData
],
str
]:
candidate_op_types
=
backend_spec
.
get_candidate_op_types
(
op_t
)
func_enums
=
[]
for
op
in
ops
:
func_enum
=
op
.
_attrs
[
"func"
]
func_enums
.
append
(
func_enum
)
funcs
=
backend_spec
.
func_enum_to_func_name
.
get
(
func_enum
)
if
funcs
is
None
:
raise
NotImplementedError
(
"Func {} is not supported!"
.
format
(
func_enum
))
for
candidate_op_t
in
candidate_op_types
:
func_name
=
funcs
.
get
(
candidate_op_t
)
if
func_name
is
not
None
:
candidate_op_types
=
backend_spec
.
get_candidate_op_types
(
candidate_op_t
)
break
if
len
(
candidate_op_types
)
==
0
:
raise
RuntimeError
(
"Cannot find a common rocm data type! candidate_op_types: {}, op_t: {}."
.
format
(
candidate_op_types
,
op_t
)
)
if
op_t
in
set
(
candidate_op_types
):
op_t
=
candidate_op_types
[
0
]
else
:
op_t
=
data_t
candidate_op_types
=
backend_spec
.
get_candidate_op_types
(
op_t
)
sub_func_metadata
=
[]
for
op
in
ops
:
func_enum
=
op
.
_attrs
[
"func"
]
funcs
=
backend_spec
.
func_enum_to_func_name
.
get
(
func_enum
)
func_name
=
None
func_op_t
=
None
for
candidate_op_t
in
candidate_op_types
:
func_name
=
funcs
.
get
(
candidate_op_t
)
if
func_name
is
not
None
:
func_op_t
=
candidate_op_t
break
if
func_name
is
None
:
raise
NotImplementedError
(
"Unsupported func {} and op type {}!"
.
format
(
func_enum
,
op_t
)
)
sub_func_metadata
.
append
(
ElementwiseMetaData
(
func_name
,
func_op_t
,
op
.
_attrs
[
"args"
],
op
.
_attrs
[
"outputs"
]
)
)
return
(
sub_func_metadata
,
op_t
)
def
_get_types_and_sizes
(
inputs
:
List
[
Tensor
],
input_accessors
:
List
[
TensorAccessor
],
output_accessors
:
List
[
TensorAccessor
],
backend_spec
:
BackendSpec
,
)
->
Tuple
[
int
,
List
[
List
[
IntVar
]],
str
]:
"""
Returns Tuple(alignment, input_broadcast_sizes, dtype)
"""
# Handle input broadcast.
output_shape
=
output_accessors
[
0
].
original_shapes
dtype
=
inputs
[
0
].
_attrs
[
"dtype"
]
input_broadcast_sizes
=
[]
min_num_elements
=
None
for
input_accessor
in
input_accessors
:
input_shape
=
input_accessor
.
original_shapes
broadcastable
,
_
=
shape_utils
.
get_broadcast_max_shape
(
output_shape
,
input_shape
)
if
not
broadcastable
:
raise
RuntimeError
(
"Input shape {} is not compatible with output shape {}!"
.
format
(
input_shape
,
output_shape
)
)
num_rightmost_non_broadcast_elements
=
len
(
input_shape
)
extended_input_shape
=
list
(
input_shape
)
if
input_shape
==
output_shape
:
input_broadcast_sizes
.
append
(
None
)
else
:
extended_input_shape
=
[
IntImm
(
1
)]
*
len
(
output_shape
)
extended_input_shape
[
len
(
output_shape
)
-
len
(
input_shape
)
:]
=
input_shape
input_broadcast_sizes
.
append
(
extended_input_shape
)
for
i
in
reversed
(
range
(
len
(
extended_input_shape
))):
if
extended_input_shape
[
i
]
!=
output_shape
[
i
]:
num_rightmost_non_broadcast_elements
-=
i
+
1
break
num_elements_for_alignments
=
shape_utils
.
get_num_rightmost_static_elements
(
extended_input_shape
,
num_rightmost_non_broadcast_elements
)
if
not
min_num_elements
:
min_num_elements
=
num_elements_for_alignments
else
:
min_num_elements
=
min
(
min_num_elements
,
num_elements_for_alignments
)
alignment
=
tensor_accessor_codegen
.
find_max_alignment
(
min_num_elements
,
output_accessors
)
# Note that we use the same alignment for accessing inputs and outputs, although
# they may have different alignment requirements. We may lose perf a little bit,
# but reduce the complexity of our jinja template. We can do some perf
# experiments later to determine if we want to chase more perf gains.
alignment
=
tensor_accessor_codegen
.
find_max_alignment
(
alignment
,
input_accessors
)
return
alignment
,
input_broadcast_sizes
,
dtype
def
_get_dynamic_dims
(
output_accessors
:
List
[
TensorAccessor
])
->
List
[
IntVar
]:
res
=
{}
for
output_accessor
in
output_accessors
:
for
dim
in
output_accessor
.
original_shapes
:
if
not
isinstance
(
dim
,
IntImm
):
res
[
dim
.
_attrs
[
"name"
]]
=
dim
return
res
.
values
()
def
_parse_func_metadata
(
ops
:
List
[
Operator
],
inputs
:
List
[
Tensor
],
outputs
:
List
[
Tensor
],
input_accessors
:
List
[
TensorAccessor
],
output_accessors
:
List
[
TensorAccessor
],
original_inputs
:
List
[
Tensor
],
original_outputs
:
List
[
Tensor
],
backend_spec
:
BackendSpec
,
)
->
FusedElementwiseMetaData
:
alignment
,
input_broadcast_sizes
,
dtype
=
_get_types_and_sizes
(
inputs
,
input_accessors
,
output_accessors
,
backend_spec
)
read_type
=
backend_spec
.
get_backend_type
(
alignment
,
dtype
,
backend_spec
.
read_num_elements_to_backend_type
)
op_type
=
backend_spec
.
get_backend_type
(
alignment
,
dtype
,
backend_spec
.
op_num_elements_to_backend_type
)
data_type
=
backend_spec
.
dtype_to_backend_type
(
dtype
)
sub_func_metadata
,
op_type
=
_get_sub_func_metadata
(
ops
,
data_type
,
op_type
,
backend_spec
)
dynamic_dims
=
_get_dynamic_dims
(
output_accessors
)
return
FusedElementwiseMetaData
(
inputs
,
outputs
,
input_accessors
,
output_accessors
,
original_inputs
,
original_outputs
,
read_type
,
op_type
,
data_type
,
input_broadcast_sizes
,
dynamic_dims
,
sub_func_metadata
,
)
def
_gen_int_var_product_str
(
int_vars
:
List
[
IntVar
],
)
->
str
:
res
=
[]
for
int_var
in
int_vars
:
if
isinstance
(
int_var
,
IntImm
):
res
.
append
(
str
(
int_var
.
_attrs
[
"values"
][
0
]))
elif
isinstance
(
int_var
,
IntVar
):
res
.
append
(
int_var
.
_attrs
[
"name"
])
else
:
raise
RuntimeError
(
"A dim must be an IntVar! Current type: {}"
.
format
(
type
(
int_var
))
)
return
" * "
.
join
(
res
)
def
_gen_input_broadcast_calculator_str
(
input_shape
:
List
[
IntVar
],
output_shape
:
List
[
IntVar
],
)
->
str
:
output_num_elements
=
[]
output_strides
=
[]
input_strides
=
[]
start_idx
=
0
for
i
,
(
input_dim
,
output_dim
)
in
enumerate
(
zip
(
input_shape
,
output_shape
)):
if
input_dim
!=
output_dim
:
assert
input_dim
==
IntImm
(
1
),
"Unexpected shapes! Input: {}, output: {}"
.
format
(
input_shape
,
output_shape
)
input_strides
.
append
(
input_shape
[
i
:])
output_strides
.
append
(
output_shape
[
i
:])
output_num_elements
.
append
(
output_shape
[
start_idx
:])
start_idx
=
i
+
1
if
start_idx
<
len
(
output_shape
):
input_strides
.
append
([
IntImm
(
1
)])
output_strides
.
append
([
IntImm
(
1
)])
output_num_elements
.
append
(
output_shape
[
start_idx
:])
res
=
[]
for
(
output_num_element
,
output_stride
,
input_stride
)
in
zip
(
output_num_elements
,
output_strides
,
input_strides
):
res
.
append
(
"{} % ({}) / ({}) * ({})"
.
format
(
"idx * N_ELEMENTS_PER_THREAD"
,
_gen_int_var_product_str
(
output_num_element
),
_gen_int_var_product_str
(
output_stride
),
_gen_int_var_product_str
(
input_stride
),
)
)
return
" + "
.
join
(
res
)
def
_gen_input_broadcast_size_str
(
input_broadcast_sizes
:
List
[
List
[
IntVar
]],
output_shape
:
List
[
IntVar
],
)
->
List
[
str
]:
res
=
[]
for
input_broadcast_size
in
input_broadcast_sizes
:
if
input_broadcast_size
is
None
:
res
.
append
(
""
)
else
:
res
.
append
(
_gen_input_broadcast_calculator_str
(
input_broadcast_size
,
output_shape
)
)
return
res
def
_gen_dynamic_dim_str
(
index_type
:
str
,
dynamic_dims
:
List
[
IntVar
],
has_type
:
bool
)
->
str
:
type_str
=
index_type
+
" "
if
has_type
else
""
res
=
", "
.
join
([
type_str
+
dim
.
_attrs
[
"name"
]
for
dim
in
dynamic_dims
])
if
res
:
res
+=
", "
return
res
def
_gen_read_inputs_str
(
fused_elementwise_metadata
:
FusedElementwiseMetaData
,
broadcast_sizes
:
List
[
str
]
):
read_inputs
=
[]
for
input_idx
,
(
input_accessor
,
broadcast_size
)
in
enumerate
(
zip
(
fused_elementwise_metadata
.
input_accessors
,
broadcast_sizes
)
):
input_name
=
f
"input_tmp
{
input_idx
}
"
data_idx
=
(
"idx"
if
not
broadcast_size
else
f
"(
{
broadcast_size
}
) / N_ELEMENTS_PER_THREAD"
)
get_strided_addr_str
=
GET_STRIDED_ADDRESS_TEMPLATE
.
render
(
tensor_accessor
=
input_accessor
,
data_ptr
=
input_name
,
data_t
=
fused_elementwise_metadata
.
data_t
,
read_t
=
fused_elementwise_metadata
.
read_t
,
data_idx
=
data_idx
,
)
read_input
=
KERNEL_READ_INPUT_TEMPLATE
.
render
(
get_strided_address
=
get_strided_addr_str
,
input_name
=
input_name
,
input_idx
=
input_idx
,
read_t
=
fused_elementwise_metadata
.
read_t
,
op_t
=
fused_elementwise_metadata
.
op_t
,
)
read_inputs
.
append
(
read_input
)
read_inputs_str
=
"
\n
"
.
join
(
read_inputs
)
return
read_inputs_str
def
_gen_write_outputs_str
(
fused_elementwise_metadata
:
FusedElementwiseMetaData
):
write_outputs
=
[]
for
output_idx
,
output_accessor
in
enumerate
(
fused_elementwise_metadata
.
output_accessors
):
output_name
=
f
"output
{
output_idx
}
"
get_strided_addr_str
=
GET_STRIDED_ADDRESS_TEMPLATE
.
render
(
tensor_accessor
=
output_accessor
,
data_ptr
=
output_name
,
data_t
=
fused_elementwise_metadata
.
data_t
,
read_t
=
fused_elementwise_metadata
.
read_t
,
data_idx
=
"idx"
,
)
write_out
=
KERNEL_WRITE_OUTPUT_TEMPLATE
.
render
(
get_strided_address
=
get_strided_addr_str
,
output_name
=
output_name
,
output_idx
=
output_idx
,
)
write_outputs
.
append
(
write_out
)
write_outputs_str
=
"
\n
"
.
join
(
write_outputs
)
return
write_outputs_str
def
_gen_kernel_function
(
func_attrs
:
Dict
[
str
,
Any
],
index_type
:
str
,
fused_elementwise_metadata
:
FusedElementwiseMetaData
,
backend_datatype_convertors
:
Dict
[
str
,
Dict
[
str
,
str
]],
)
->
str
:
output_params_decl
=
","
.
join
(
[
KERNEL_DECL_OUTPUT_PARAM_TEMPLATE
.
render
(
read_t
=
fused_elementwise_metadata
.
read_t
,
idx
=
i
)
for
i
,
_
in
enumerate
(
fused_elementwise_metadata
.
outputs
)
]
)
input_params_decl
=
","
.
join
(
[
KERNEL_DECL_INPUT_PARAM_TEMPLATE
.
render
(
read_t
=
fused_elementwise_metadata
.
read_t
,
idx
=
i
)
for
i
,
_
in
enumerate
(
fused_elementwise_metadata
.
inputs
)
]
)
broadcast_sizes
=
_gen_input_broadcast_size_str
(
fused_elementwise_metadata
.
input_broadcast_sizes
,
fused_elementwise_metadata
.
output_accessors
[
0
].
original_shapes
,
)
read_inputs_str
=
_gen_read_inputs_str
(
fused_elementwise_metadata
,
broadcast_sizes
)
define_outputs
=
KERNEL_DEFINE_OUTPUTS_TEMPLATE
.
render
(
read_t
=
fused_elementwise_metadata
.
read_t
,
op_t
=
fused_elementwise_metadata
.
op_t
,
indexes
=
list
(
range
(
len
(
fused_elementwise_metadata
.
outputs
))),
)
write_outputs_str
=
_gen_write_outputs_str
(
fused_elementwise_metadata
)
input_names
=
[
KERNEL_TMP_INPUT_TEMPLATE
.
render
(
idx
=
i
)
for
i
,
_
in
enumerate
(
fused_elementwise_metadata
.
inputs
)
]
output_names
=
[
KERNEL_TMP_OUTPUT_TEMPLATE
.
render
(
idx
=
i
)
for
i
,
_
in
enumerate
(
fused_elementwise_metadata
.
outputs
)
]
fused_funcs
=
gen_function_single_thread
(
fused_elementwise_metadata
,
input_names
,
output_names
,
backend_datatype_convertors
,
)
kernel_func
=
KERNEL_TEMPLATE
.
render
(
func_name
=
func_attrs
[
"name"
],
index_type
=
index_type
,
output_params
=
output_params_decl
,
input_params
=
input_params_decl
,
dynamic_dims
=
_gen_dynamic_dim_str
(
index_type
,
fused_elementwise_metadata
.
dynamic_dims
,
has_type
=
True
),
read_inputs
=
read_inputs_str
,
define_outputs
=
define_outputs
,
write_outputs
=
write_outputs_str
,
fused_funcs
=
fused_funcs
,
)
return
kernel_func
def fused_elementwise_gen_function(
    func_attrs: Dict[str, Any],
    custom_libs: str,
    head_template: str,
    backend_spec: BackendSpec,
) -> str:
    """Generates fused_elementwise function definition."""
    ops = func_attrs["elementwise_ops"]
    inputs = func_attrs["inputs"]
    outputs = func_attrs["outputs"]
    input_accessors = func_attrs["input_accessors"]
    output_accessors = func_attrs["output_accessors"]
    original_inputs = func_attrs["original_inputs"]
    original_outputs = func_attrs["original_outputs"]
    fused_elementwise_metadata = _parse_func_metadata(
        ops,
        inputs,
        outputs,
        input_accessors,
        output_accessors,
        original_inputs,
        original_outputs,
        backend_spec,
    )
    # Dump data types into func_attrs for testing purposes.
    func_attrs["read_t"] = fused_elementwise_metadata.read_t
    func_attrs["op_t"] = fused_elementwise_metadata.op_t
    func_attrs["data_t"] = fused_elementwise_metadata.data_t

    tensor_accessor_lib = tensor_accessor_codegen.get_libs()
    tensor_accessor_lib_str = "\n\n" + tensor_accessor_lib + "\n\n"
    kernel_function = _gen_kernel_function(
        func_attrs,
        backend_spec.index_type,
        fused_elementwise_metadata,
        backend_spec.backend_datatype_convertors,
    )
    output_params_decl = ",".join(
        [
            FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render(idx=i)
            for i, _ in enumerate(fused_elementwise_metadata.outputs)
        ]
    )
    input_params_decl = ",".join(
        [
            FUNC_DECL_INPUT_PARAM_TEMPLATE.render(idx=i)
            for i, _ in enumerate(fused_elementwise_metadata.inputs)
        ]
    )
    kernel_call_output_params = ",".join(
        [
            KERNEL_CALL_OUTPUT_PARAM_TEMPLATE.render(
                read_t=fused_elementwise_metadata.read_t, idx=i
            )
            for i, _ in enumerate(fused_elementwise_metadata.outputs)
        ]
    )
    kernel_call_input_params = ",".join(
        [
            KERNEL_CALL_INPUT_PARAM_TEMPLATE.render(
                read_t=fused_elementwise_metadata.read_t, idx=i
            )
            for i, _ in enumerate(fused_elementwise_metadata.inputs)
        ]
    )
    constant = CONSTANT_TEMPLATE.render(
        read_t=fused_elementwise_metadata.read_t,
        op_t=fused_elementwise_metadata.op_t,
        data_t=fused_elementwise_metadata.data_t,
    )
    function = FUNC_TEMPLATE.render(
        prefix=backend_spec.prefix,
        index_type=backend_spec.index_type,
        head=backend_spec.header_src_template.render(extra_header=head_template),
        constant=constant,
        custom_libs=custom_libs,
        tensor_accessor_lib=tensor_accessor_lib_str,
        kernel_function=kernel_function,
        func_name=func_attrs["name"],
        output_params=output_params_decl,
        input_params=input_params_decl,
        dynamic_dims_decl=_gen_dynamic_dim_str(
            backend_spec.index_type,
            fused_elementwise_metadata.dynamic_dims,
            has_type=True,
        ),
        dynamic_dims_call=_gen_dynamic_dim_str(
            backend_spec.index_type,
            fused_elementwise_metadata.dynamic_dims,
            has_type=False,
        ),
        kernel_call_output_params=kernel_call_output_params,
        kernel_call_input_params=kernel_call_input_params,
    )
    return function
def fused_elementwise_gen_function_decl(
    func_attrs,
    backend_spec: BackendSpec,
):
    """Generates fused_elementwise function declaration."""
    func_name = func_attrs["name"]
    ops = func_attrs["elementwise_ops"]
    inputs = func_attrs["inputs"]
    outputs = func_attrs["outputs"]
    input_accessors = func_attrs["input_accessors"]
    output_accessors = func_attrs["output_accessors"]
    original_inputs = func_attrs["original_inputs"]
    original_outputs = func_attrs["original_outputs"]
    fused_elementwise_metadata = _parse_func_metadata(
        ops,
        inputs,
        outputs,
        input_accessors,
        output_accessors,
        original_inputs,
        original_outputs,
        backend_spec,
    )
    output_params_decl = ",".join(
        [
            FUNC_DECL_OUTPUT_PARAM_TEMPLATE.render(idx=i)
            for i, _ in enumerate(fused_elementwise_metadata.outputs)
        ]
    )
    input_params_decl = ",".join(
        [
            FUNC_DECL_INPUT_PARAM_TEMPLATE.render(idx=i)
            for i, _ in enumerate(fused_elementwise_metadata.inputs)
        ]
    )
    function_decl = FUNC_DECL_TEMPLATE.render(
        prefix=backend_spec.prefix,
        index_type=backend_spec.index_type,
        func_name=func_name,
        output_params=output_params_decl,
        input_params=input_params_decl,
        dynamic_dims=_gen_dynamic_dim_str(
            backend_spec.index_type,
            fused_elementwise_metadata.dynamic_dims,
            has_type=True,
        ),
    )
    return function_decl
def fused_elementwise_gen_function_call(
    func_attrs,
    indent: str,
    backend_spec: BackendSpec,
):
    """Generates fused_elementwise function call."""
    ops = func_attrs["elementwise_ops"]
    inputs = func_attrs["inputs"]
    outputs = func_attrs["outputs"]
    input_accessors = func_attrs["input_accessors"]
    output_accessors = func_attrs["output_accessors"]
    original_inputs = func_attrs["original_inputs"]
    original_outputs = func_attrs["original_outputs"]
    fused_elementwise_metadata = _parse_func_metadata(
        ops,
        inputs,
        outputs,
        input_accessors,
        output_accessors,
        original_inputs,
        original_outputs,
        backend_spec,
    )
    output_params = ",".join([output._attrs["name"] for output in outputs])
    input_params = ",".join([input._attrs["name"] for input in inputs])
    num_elements_calculator = _gen_int_var_product_str(
        output_accessors[0].original_shapes
    )
    return FUNC_CALL_TEMPLATE.render(
        stream=backend_spec.stream,
        func_name=func_attrs["name"],
        index_type=backend_spec.index_type,
        calculate_n=num_elements_calculator,
        output_params=output_params,
        input_params=input_params,
        dynamic_dims=_gen_dynamic_dim_str(
            backend_spec.index_type,
            fused_elementwise_metadata.dynamic_dims,
            has_type=False,
        ),
        indent=indent,
    )
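# Illustrative sketch (added here for clarity; not part of the original module).
# The tensor names below are hypothetical, but they show how the joins above
# turn tensor attributes into the rendered call's parameter lists:
#
#     outputs[i]._attrs["name"] == "output_0"            -> output_params == "output_0"
#     inputs[i]._attrs["name"] in ("input_0", "input_1") -> input_params  == "input_0,input_1"
#
# FUNC_CALL_TEMPLATE then splices these strings, together with the rendered
# dynamic-dim arguments, into the generated C++ call site.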
python/AIT implementation/sample files/gemm_common.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Common template for gemm
"""
import os
import re
from collections import OrderedDict
from hashlib import sha1

import jinja2

from ...common import gemm_common
from ...target import Target

# pylint: disable=C0103,C0415,W0611,C0301
EXTRA_SHAPE_TEMPLATE = jinja2.Template(
"""
{{indent}}const int64_t stride_a = *a_dim1;
{{indent}}const int64_t stride_b = *b_dim1;
{{indent}}const int64_t stride_c = *c_dim1;
"""
)
INSTANCE_TEMPLATE = jinja2.Template(
"""
{{config}}
using {{name}} = {{ config_name }};
"""
)
EXEC_TEMPLATE = jinja2.Template(
"""
{{indent}}auto op = {{instance}}{};
{{indent}}auto invoker = op.MakeInvoker();
{{indent}}auto argument = op.MakeArgument(
{{problem_args}}
{{indent}});
{{indent}}if(!op.IsSupportedArgument(argument)) {
{{indent}} throw std::runtime_error(
{{indent}} "wrong! device_gemm with the specified compilation parameters does "
{{indent}} "not support this Gemm problem");
{{indent}}}
{% if is_profiler %}
{{indent}}auto workspace_size = op.GetWorkSpaceSize(&argument);
{{indent}}GLOBAL_WORKSPACE_SIZE = workspace_size;
{% endif %}
{{indent}}invoker.Run(argument, StreamConfig{stream, false});
{{indent}}return;
"""
)
EXTRA_HEADER_TEMPLATE = jinja2.Template(
"""
{% if gemm_flag == "" %}
#include "include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
{% elif gemm_flag == "permute_m2n3" %}
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% elif "bias" in gemm_flag or has_d0 %}
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp"
{% if gemm_flag == "bias_permute" %}
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
{% elif gemm_flag in ["bias_permute_m2n3", "bias_permute_m3n2"] %}
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
{% endif %}
{% endif %}
"""
)
SRC_TEMPLATE = jinja2.Template(
"""
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
{{extra_header}}
{{extra_code}}
{{instances}}
void {{function_name}}(
void * in_ptr,
void * weight_ptr,
void * out_ptr,
{% if "bias" in gemm_flag %}
void * bias_ptr,
{% endif %}
{% if has_d0 %}
void * d0_ptr,
{% endif %}
{% if has_d1 %}
void * d1_ptr,
{% endif %}
{% for idx in range(ndims) %}
int64_t* a_dim{{idx}},
{% endfor %}
{% for idx in range(ndims) %}
int64_t* b_dim{{idx}},
{% endfor %}
{% for idx in range(ndims) %}
int64_t* c_dim{{idx}},
{% endfor %}
{% for idx in range(pdims) %}
const int p_dim{{idx}},
{% endfor %}
hipStream_t stream
) {
{{shape_func}}
{{extra_shape}}
{{input_addr_calculator}}
{{output_addr_calculator}}
{{exec_paths}}
throw std::runtime_error(
"Unsupported workload for this gemm specialization."
);
}
"""
)
FUNC_CALL_TEMPLATE = jinja2.Template(
"""
{{indent}}{{func_name}}(
{{indent}} {{in_ptr}},
{{indent}} {{weight_ptr}},
{{indent}} {{out_ptr}},
{% if "bias" in gemm_flag %}
{{indent}} {{bias_ptr}},
{% endif %}
{% if d0_ptr != "" %}
{{indent}} {{d0_ptr}},
{% endif %}
{% if d1_ptr != "" %}
{{indent}} {{d1_ptr}},
{% endif %}
{% for dim in adims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in bdims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in cdims %}
{{indent}} {{dim}},
{% endfor %}
{% for dim in pdims %}
{{indent}} {{dim}},
{% endfor %}
{{indent}} stream
{{indent}});
"""
)
PROBLEM_ARGS_TEMPLATE = jinja2.Template(
"""
{{indent}} static_cast<ck::half_t *>(in_ptr),
{{indent}} static_cast<ck::half_t *>(weight_ptr),
{% if gemm_flag == "bias_permute" %}
{{indent}} static_cast<ck::half_t *>(bias_ptr),
{% elif gemm_flag == "bias_permute_m2n3" %}
{{indent}} std::array<const void*, 1>{static_cast<ck::half_t *>(bias_ptr)},
{% elif gemm_flag == "permute_m2n3" %}
{{indent}} {},
{% else %}
{% if "bias" in gemm_flag and not has_d0 %}
{{indent}} std::array<const void*, 1>{static_cast<ck::half_t *>(bias_ptr)},
{% elif has_d0 and not has_d1 %}
{{indent}} std::array<const void*, 2>{static_cast<ck::half_t *>(bias_ptr),
static_cast<ck::half_t *>(d0_ptr)},
{% elif has_d1 %}
{{indent}} std::array<const void*, 3>{static_cast<ck::half_t *>(bias_ptr),
static_cast<ck::half_t *>(d0_ptr),
static_cast<ck::half_t *>(d1_ptr)},
{% endif %}
{% endif %}
{{indent}} static_cast<ck::half_t *>(out_ptr),
{% if gemm_flag not in ["permute_m2n3", "bias_permute_m2n3", "bias_permute_m3n2"] %}
{{indent}} M,
{{indent}} N,
{{indent}} K,
{{indent}} stride_a,
{{indent}} stride_b,
{% endif %}
{% if gemm_flag == "bias_permute" %}
{{indent}} {M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1},
{{indent}} {M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1},
{% elif gemm_flag in ["permute_m2n3", "bias_permute_m2n3", "bias_permute_m3n2"] %}
{{indent}} a_ms_ks_lengths,
{{indent}} a_ms_ks_strides,
{{indent}} b_ns_ks_lengths,
{{indent}} b_ns_ks_strides,
{% if gemm_flag == "permute_m2n3" %}
{{indent}} {},
{{indent}} {},
{% else %}
{{indent}} std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
{{indent}} std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
{% endif %}
{{indent}} e_ms_ns_lengths,
{{indent}} e_ms_ns_strides,
{% else %}
{% if "bias" in gemm_flag and not has_d0 %}
{{indent}} std::array<ck::index_t, 1>{0},
{% elif has_d0 and not has_d1 %}
{{indent}} std::array<ck::index_t, 2>{0, static_cast<int>(stride_c)},
{% elif has_d1 %}
{{indent}} std::array<ck::index_t, 3>{0, static_cast<int>(stride_c), static_cast<int>(stride_c)},
{% endif %}
{{indent}} stride_c,
{% endif %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{},
{{indent}} ck::tensor_operation::element_wise::PassThrough{},
{% if gemm_flag == "" %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{}
{% elif gemm_flag == "permute_m2n3" %}
{{indent}} ck::tensor_operation::element_wise::PassThrough{}
{% elif gemm_flag == "bias" or "bias_permute" in gemm_flag %}
{{indent}} ck::tensor_operation::element_wise::Add{}
{% elif gemm_flag == "bias_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddRelu{}
{% elif gemm_flag == "bias_fast_gelu" %}
{{indent}} ck::tensor_operation::element_wise::AddFastGelu{}
{% elif gemm_flag == "bias_swish" %}
{{indent}} ck::tensor_operation::element_wise::AddHardswish{}
{% elif gemm_flag == "bias_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddTanh{}
{% elif gemm_flag == "bias_sigmoid" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoid{}
{% elif gemm_flag == "bias_add" %}
{{indent}} ck::tensor_operation::element_wise::AddAdd{}
{% elif gemm_flag == "bias_mul" %}
{{indent}} ck::tensor_operation::element_wise::AddMul{}
{% elif gemm_flag == "bias_mul_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddMulTanh{}
{% elif gemm_flag == "bias_add_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddRelu{}
{% elif gemm_flag == "bias_add_fast_gelu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddFastGelu{}
{% elif gemm_flag == "bias_sigmoid_mul" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoidMul{}
{% elif gemm_flag == "bias_sigmoid_mul_tanh" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoidMulTanh{}
{% elif gemm_flag == "bias_mul_add" %}
{{indent}} ck::tensor_operation::element_wise::AddMulAdd{}
{% elif gemm_flag == "bias_add_add" %}
{{indent}} ck::tensor_operation::element_wise::AddAddAdd{}
{% elif gemm_flag == "bias_add_add_relu" %}
{{indent}} ck::tensor_operation::element_wise::AddAddAddRelu{}
{% endif %}
"""
)
TENSOR_DECL_TEMPLATE = jinja2.Template(
"""
int64_t a_ptr_sz = M*K;
int64_t b_ptr_sz = N*K;
int64_t c_ptr_sz = M*N;
int64_t ptr_max_sz = std::max({a_ptr_sz, b_ptr_sz, c_ptr_sz});
// TODO: special pool size for 8M L2 cache
// need to tune it for other devices
int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 23) / ptr_max_sz)));
memory_pool->AllocateHalfTensor(a_ptr_sz, mem_pool_sz); // x: index 0
memory_pool->AllocateHalfTensor(b_ptr_sz, mem_pool_sz); // w: index 1
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // y: index 2
{% if "bias" in gemm_flag %}
memory_pool->AllocateHalfTensor(N, mem_pool_sz); // b: index 3
{% endif %}
{% if has_d0 %}
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d0 ptr: index 4
{% endif %}
{% if has_d1 %}
memory_pool->AllocateHalfTensor(c_ptr_sz, mem_pool_sz); // d1 ptr: index 5
{% endif %}
"""
)
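# Worked example (added for clarity; not part of the original module): the pool
# size formula above is std::max(2, std::min(64, (1 << 23) / ptr_max_sz)).
# For a hypothetical M = N = K = 1024 problem, a_ptr_sz = b_ptr_sz = c_ptr_sz =
# 1024 * 1024 = 1,048,576 elements, so (1 << 23) / ptr_max_sz = 8,388,608 /
# 1,048,576 = 8 and each tensor is allocated with 8 rotating copies (clamped to
# the [2, 64] range). The TODO above ties this sizing to an 8M L2 cache:
# rotating through several copies keeps successive profiler calls from always
# hitting the same cached buffer.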
STRUCTS_DEF_TEMPLATE = jinja2.Template(
"""
struct ProfilerMemoryPool {
ProfilerMemoryPool() {
std::random_device rd;
gen = std::mt19937(rd());
uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
offsets.reserve(512);
strides.reserve(512);
copies.reserve(512);
ptrs.reserve(512);
}
~ProfilerMemoryPool() {
for(int i = 0; i < ptrs.size(); i++){
hipFree(ptrs[i]);
}
}
template <typename DType>
DType* AllocateGaussianTensor(int64_t size) {
size_t length = size * sizeof(DType);
DType *d_x;
hipMalloc(&d_x, length);
float mean = 0.0f;
float stddev = 1.0f;
uint64_t seed = uniform_dist(gen);
rocrand_set_seed(generator, seed);
rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
return d_x;
}
ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
return reinterpret_cast<ck::half_t*>(
AllocateGaussianTensor<ck::half_t>(size));
}
int AllocateHalfTensor(int64_t size, int64_t copy) {
offsets.push_back(0);
strides.push_back(size);
copies.push_back(copy);
auto ptr = AllocateHalfGaussianTensor(size * copy);
ptrs.push_back(reinterpret_cast<void*>(ptr));
return ptrs.size() - 1;
}
ck::half_t* RequestHalfTensorByIdx(int idx) {
auto copy = copies.at(idx);
auto offset = offsets.at(idx);
auto stride = strides.at(idx);
ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
ptr += offset;
offset += stride;
if (offset == copy * stride) {
offset = 0;
}
offsets[idx] = offset;
return ptr;
}
std::vector<int64_t> offsets;
std::vector<int64_t> strides;
std::vector<int64_t> copies;
std::vector<void*> ptrs;
std::mt19937 gen;
std::uniform_int_distribution<int64_t> uniform_dist;
rocrand_generator generator;
};
// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p) const
{
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl() {
hipGetErrorString(hipEventCreate(&mStart));
hipGetErrorString(hipEventCreate(&mEnd));
}
~KernelTimerImpl() {
hipGetErrorString(hipEventDestroy(mStart));
hipGetErrorString(hipEventDestroy(mEnd));
}
void Start() {
hipGetErrorString(hipDeviceSynchronize());
hipGetErrorString(hipEventRecord(mStart, nullptr));
}
void End() {
hipGetErrorString(hipEventRecord(mEnd, nullptr));
hipGetErrorString(hipEventSynchronize(mEnd));
}
float GetElapsedTime() const {
float time;
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
return time;
}
hipEvent_t mStart, mEnd;
};
// >>> hack end
"""
)
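# Worked example (added for clarity; not part of the original module):
# RequestHalfTensorByIdx() rotates through the copies allocated by
# AllocateHalfTensor(size, copy). With size = 1024 and copy = 4, successive
# requests for the same index return ptr + 0, ptr + 1024, ptr + 2048,
# ptr + 3072, and then wrap back to ptr + 0 once offset reaches copy * stride.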
PROFILER_TEMPLATE = jinja2.Template(
"""
size_t GLOBAL_WORKSPACE_SIZE = 0;
{{op_func}}
{{structs_def}}
int main(int argc, char** argv) {
if (argc < 4) {
throw std::runtime_error("wrong params");
}
{{args_parse}}
auto memory_pool = std::make_unique<ProfilerMemoryPool>();
hipStream_t stream = nullptr;
{{tensor_decl}}
// TODO: random init
// warmup
for(int i = 0; i < 3; ++i) {
{{func_call}}
}
// run
auto timer = new KernelTimerImpl();
timer->Start();
for(int i = 0; i < 5; ++i) {
{{func_call}}
}
timer->End();
std::cout << "WS:" <<GLOBAL_WORKSPACE_SIZE<<std::endl;
std::cout << "TIME:" << timer->GetElapsedTime() << std::endl;
delete(timer);
}
"""
)
FUNC_DECL_TEMPLATE = jinja2.Template(
"""
void {{func_name}}(
void *,
void *,
void *,
{% if "bias" in gemm_flag %}
void *,
{% endif %}
{% if has_d0 %}
void *,
{% endif %}
{% if has_d1 %}
void *,
{% endif %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(ndims) %}
int64_t*,
{% endfor %}
{% for idx in range(pdims) %}
const int,
{% endfor %}
hipStream_t
);
"""
)
def has_d0(func_attrs):
    return func_attrs.get("num_sources", 0) >= 1


def has_d1(func_attrs):
    return func_attrs.get("num_sources", 0) >= 2
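# Example (added for clarity; not part of the original module): with a
# hypothetical func_attrs = {"num_sources": 2}, both has_d0() and has_d1()
# return True, so the templates above emit the extra d0/d1 pointer parameters;
# with "num_sources" absent or 0, neither extra operand is generated.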
def emit_instance(op):
    """Emit instance."""
    import ck_lib  # noqa: F401

    op_def = op.emit()
    return op_def


def extract_config(op_kind, extra_kind, f_proc_op):
    """Extract (operation name, operation instance) pairs
    from all operation candidates.

    Parameters
    ----------
    op_kind : ck_lib.library.OperationKind
        Operation kind.
    extra_kind : ck_lib.library.[AnyKind]
        Used as an extra flag to distinguish kernels,
        e.g. bias_add_relu vs. add_relu_bias.
    f_proc_op : function
        Used to filter operations.

    Returns
    -------
    Dict
        Extracted (operation name, operation instance) pairs.
    """
    gemm_ops = OrderedDict()
    extract_ops = list(Target.current()._operators[op_kind][extra_kind].items())
    for key, value in extract_ops:
        op = value[0]
        ret = f_proc_op(op)
        if len(ret) > 0:
            for op_inst in ret:
                gemm_ops[key] = op_inst
    return gemm_ops
def extract_config_name(config):
    """Extract the name from the statement, e.g. 'model' for 'using model = xxx'.

    Parameters
    ----------
    config : str
        Configuration as a string in the format of 'using model = xxx'.

    Returns
    -------
    str
        Extracted name from the statement.

    Raises
    ------
    RuntimeError
        Invalid config.
    """
    pattern = re.compile(r"\s*using\s(.*?)\s=")
    decl = config.split("\n")[1]
    match = pattern.match(decl)
    if match is None:
        raise RuntimeError("Invalid config:\n" + config)
    return match.groups()[0]
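# Usage sketch (added for clarity; not part of the original module). The sample
# string is hypothetical but follows the shape the regex expects: the second
# line of the emitted config must look like 'using <name> = <type>;'.
#
#     sample = "// emitted instance\nusing DeviceGemmInstance = SomeCkType;"
#     extract_config_name(sample)  # -> "DeviceGemmInstance"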
def gen_profiler(
    func_attrs,
    workdir,
    dim_info_dict,
    args_parse,
    gemm_flag,
    extra_code="",
    ndims=2,
    extra_shape_template=EXTRA_SHAPE_TEMPLATE,
    problem_args_template=PROBLEM_ARGS_TEMPLATE,
    extra_header_template=EXTRA_HEADER_TEMPLATE,
    tensor_decl_template=TENSOR_DECL_TEMPLATE,
):
    """Generates standalone executables for the profiler.

    Parameters
    ----------
    func_attrs : Dict
        Operation attributes.
    workdir : str
        Directory to store the generated outputs.
    dim_info_dict : Dict[str, DimInfo]
        Generated from gemm._extract_dims().
        Used to store the mapping from dim_names to input / output tensor dims.
    args_parse : str
        Profiler input argument parser.
    gemm_flag : str
        Flag telling which backend should be generated. Options are
        '', 'bias', 'bias_relu', 'bias_sigmoid', 'bias_add_relu'.
    extra_code : str
        Extra code for self-defined operators.
    ndims : int
        Number of dims for each parameter, 2 for gemm, 3 for bmm.
    extra_shape_template : jinja2.Template
        Shape evaluation template.
    problem_args_template : jinja2.Template
        Problem args template for the profiler.
    extra_header_template : jinja2.Template
        Extra header template, as gemm and bmm use different headers.
    tensor_decl_template : jinja2.Template
        Tensor declaration template.
    """
    op_type = func_attrs["op"]
    op_instance = func_attrs["op_instance"]
    # shape function
    op_func_shape = gemm_common.gen_shape_eval_code(
        indent=2, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True
    )
    adims = ["&a_dim" + str(i) for i in range(ndims)]
    bdims = ["&b_dim" + str(i) for i in range(ndims)]
    cdims = ["&c_dim" + str(i) for i in range(ndims)]
    pdims = []
    if func_attrs.get("shape") is not None:
        pdims = ["p_dim" + str(i) for i in range(len(func_attrs["shape"]))]
    extra_shape_func = extra_shape_template.render(indent="  ")
    file_pairs = []
    has_d0_flag = has_d0(func_attrs)
    has_d1_flag = has_d1(func_attrs)
    for op_name, op in op_instance.items():
        config = emit_instance(op)
        config_name = extract_config_name(config)
        instance = INSTANCE_TEMPLATE.render(
            name="DeviceGemmInstance", config_name=config_name, config=config
        )
        problem_args = problem_args_template.render(
            indent="  ",
            gemm_flag=gemm_flag,
            has_d0=has_d0_flag,
            has_d1=has_d1_flag,
        )
        exec_program = EXEC_TEMPLATE.render(
            indent="  ",
            instance="DeviceGemmInstance",
            problem_args=problem_args,
            is_profiler=True,
        )
        extra_header = extra_header_template.render(
            gemm_flag=gemm_flag, has_d0=has_d0_flag
        )
        op_func = SRC_TEMPLATE.render(
            instances=instance,
            function_name="gemm",
            ndims=ndims,
            pdims=len(pdims),
            has_d0=has_d0_flag,
            has_d1=has_d1_flag,
            shape_func=op_func_shape,
            extra_shape=extra_shape_func,
            exec_paths=exec_program,
            extra_code=extra_code,
            gemm_flag=gemm_flag,
            extra_header=extra_header,
        )
        structs_def = STRUCTS_DEF_TEMPLATE.render()
        tensor_decl = tensor_decl_template.render(
            gemm_flag=gemm_flag, has_d0=has_d0_flag, has_d1=has_d1_flag
        )
        func_call = FUNC_CALL_TEMPLATE.render(
            func_name="gemm",
            in_ptr="(void *) memory_pool->RequestHalfTensorByIdx(0)",
            weight_ptr="(void *) memory_pool->RequestHalfTensorByIdx(1)",
            out_ptr="(void *) memory_pool->RequestHalfTensorByIdx(2)",
            bias_ptr="(void *) memory_pool->RequestHalfTensorByIdx(3)",
            d0_ptr="(void *) memory_pool->RequestHalfTensorByIdx(4)"
            if has_d0_flag
            else "",
            d1_ptr="(void *) memory_pool->RequestHalfTensorByIdx(5)"
            if has_d1_flag
            else "",
            adims=adims,
            bdims=bdims,
            cdims=cdims,
            pdims=pdims,
            gemm_flag=gemm_flag,
        )
        code = PROFILER_TEMPLATE.render(
            structs_def=structs_def,
            op_func=op_func,
            args_parse=args_parse,
            tensor_decl=tensor_decl,
            func_call=func_call,
        )
        prefix = os.path.join(workdir, "profiler", op_type)
        if not os.path.exists(prefix):
            os.makedirs(prefix)
        src_path = os.path.join(prefix, op_name + ".cpp")
        obj_path = os.path.join(prefix, op_name)
        if os.path.exists(obj_path):
            continue
        with open(src_path, "w") as fo:
            fo.write(code)
        file_pairs.append((src_path, obj_path))
    return file_pairs
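# Note (added for clarity; not part of the original module): for each op
# instance, gen_profiler() writes the rendered profiler source to
# <workdir>/profiler/<op_type>/<op_name>.cpp, skips instances whose compiled
# binary <workdir>/profiler/<op_type>/<op_name> already exists, and returns the
# (source, object) path pairs for the build step that follows.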
def gen_function(
    func_attrs,
    exec_cond_template,
    dim_info_dict,
    gemm_flag,
    extra_code="",
    ndims=2,
    extra_shape_template=EXTRA_SHAPE_TEMPLATE,
    problem_args_template=PROBLEM_ARGS_TEMPLATE,
    extra_header_template=EXTRA_HEADER_TEMPLATE,
    input_addr_calculator="",
    output_addr_calculator="",
):
    """Generates the function body.

    Parameters
    ----------
    func_attrs : Dict
        Operation attributes.
    exec_cond_template : jinja2.Template
        Generates the if statement that selects which kernel to execute.
    dim_info_dict : Dict[str, DimInfo]
        Generated from gemm._extract_dims().
        Used to store the mapping from dim_names to input / output tensor dims.
    gemm_flag : str
        Flag telling which backend should be generated. Options are
        '', 'bias', 'bias_relu', 'bias_add_relu'.
    extra_code : str
        Extra code for self-defined operators.
    ndims : int
        Number of dims for each parameter, 2 for gemm, 3 for bmm.
    extra_shape_template : jinja2.Template
        Shape evaluation template.
    extra_header_template : jinja2.Template
        Extra header template, as gemm and bmm use different headers.
    input_addr_calculator : str
        Used to adjust the input address based on input tensor accessors, if any.
    output_addr_calculator : str
        Used to adjust the output address based on output tensor accessors, if any.

    Returns
    -------
    str
        The rendered template of the generated function body.
    """
    func_name = func_attrs["name"]
    exec_path = func_attrs["exec_path"]
    op_instance = func_attrs["op_instance"]
    inst_def_flag = set()
    instances = {}
    instance_decl = ""
    has_d0_flag = has_d0(func_attrs)
    has_d1_flag = has_d1(func_attrs)
    for key, value in exec_path.items():
        fname = "f" + sha1(key.encode()).hexdigest()
        algo = value.algo
        if algo not in inst_def_flag:
            config = emit_instance(op_instance[algo])
            inst_def_flag.add(algo)
        else:
            config = ""
        inst = INSTANCE_TEMPLATE.render(
            config=config, name=fname, config_name=extract_config_name(config)
        )
        instances[key] = inst
        instance_decl += inst
    extra_shape_func = extra_shape_template.render(indent="  ")
    shape_eval_func = gemm_common.gen_shape_eval_code(
        indent=1, dtype="int64_t", dim_info_dict=dim_info_dict, is_ptr=True
    )
    exec_paths = ""
    for key, _ in instances.items():
        fname = "f" + sha1(key.encode()).hexdigest()
        problem_args = problem_args_template.render(
            indent="    ",
            gemm_flag=gemm_flag,
            has_d0=has_d0_flag,
            has_d1=has_d1_flag,
        )
        program = EXEC_TEMPLATE.render(
            indent="    ",
            instance=fname,
            problem_args=problem_args,
            is_profiler=False,
        )
        exec_inst = exec_cond_template.render(indent="  ", cond=key, program=program)
        exec_paths += exec_inst
    extra_header = extra_header_template.render(
        gemm_flag=gemm_flag, has_d0=has_d0(func_attrs)
    )
    pdims = len(func_attrs["shape"]) if func_attrs.get("shape") is not None else 0
    return SRC_TEMPLATE.render(
        instances=instance_decl,
        function_name=func_name,
        shape_func=shape_eval_func,
        extra_shape=extra_shape_func,
        input_addr_calculator=input_addr_calculator,
        output_addr_calculator=output_addr_calculator,
        exec_paths=exec_paths,
        extra_code=extra_code,
        extra_header=extra_header,
        gemm_flag=gemm_flag,
        ndims=ndims,
        pdims=pdims,
        has_d0=has_d0_flag,
        has_d1=has_d1_flag,
    )


def gen_function_decl(func_name, gemm_flag, ndims=2, pdims=0, has_d0="", has_d1=""):
    """Generates function declarations.

    Parameters
    ----------
    func_name : str
        Function name.
    gemm_flag : str
        Flag telling which backend should be generated. Options are
        '', 'bias', 'bias_relu'.
    ndims : int
        Number of dims for each parameter, 2 for gemm, 3 for bmm.

    Returns
    -------
    str
        The rendered template of the function declaration.
    """
    return FUNC_DECL_TEMPLATE.render(
        func_name=func_name,
        gemm_flag=gemm_flag,
        ndims=ndims,
        pdims=pdims,
        has_d0=has_d0,
        has_d1=has_d1,
    )


def gen_function_call(func_attrs, indent="  ", gemm_flag=""):
    """Generates the function call.

    Parameters
    ----------
    func_attrs : Dict
        Stores the operation attributes.
    indent : str, optional
        Indent for codegen, target dependent (e.g. C++, python). By default "  ".
    gemm_flag : str
        Flag telling which backend should be generated. Options are
        '', 'bias', 'bias_relu'.

    Returns
    -------
    str
        The rendered template of the generated function call.
    """
    a = func_attrs["inputs"][0]
    b = func_attrs["inputs"][1]
    c = func_attrs["outputs"][0]
    bias_ptr = ""
    if "bias" in gemm_flag:
        bias = func_attrs["inputs"][2]
        bias_ptr = bias._attrs["name"]
    d0_ptr = ""
    if has_d0(func_attrs):
        d0 = func_attrs["inputs"][3]
        d0_ptr = d0._attrs["name"]
    d1_ptr = ""
    if has_d1(func_attrs):
        d1 = func_attrs["inputs"][4]
        d1_ptr = d1._attrs["name"]
    adims = [
        "&" + dim._attrs["name"]
        for dim in func_attrs["input_accessors"][0].original_shapes
    ]
    bdims = [
        "&" + dim._attrs["name"]
        for dim in func_attrs["input_accessors"][1].original_shapes
    ]
    cdims = [
        "&" + dim._attrs["name"]
        for dim in func_attrs["output_accessors"][0].original_shapes
    ]
    pdims = []
    if func_attrs.get("shape") is not None:
        pdims = list(func_attrs["shape"])
    return FUNC_CALL_TEMPLATE.render(
        func_name=func_attrs["name"],
        in_ptr=a._attrs["name"],
        weight_ptr=b._attrs["name"],
        out_ptr=c._attrs["name"],
        bias_ptr=bias_ptr,
        d0_ptr=d0_ptr,
        d1_ptr=d1_ptr,
        adims=adims,
        bdims=bdims,
        cdims=cdims,
        pdims=pdims,
        indent=indent,
        gemm_flag=gemm_flag,
    )


def default_fproc_f16(*, op, a_layout, b_layout, c_layout):
    """Filter the input operation by layouts.

    Parameters
    ----------
    op : operation
        aitemplate operation.
    a_layout : ck_lib.library.LayoutType
        a layout type.
    b_layout : ck_lib.library.LayoutType
        b layout type.
    c_layout : ck_lib.library.LayoutType
        c layout type.

    Returns
    -------
    List
        List of filtered ops (can be empty).
    """
    import copy

    import ck_lib

    ret = []
    data_type = ck_lib.library.DataType.f16
    acc_type = ck_lib.library.DataType.f32
    if (
        op.A.element == data_type
        and op.B.element == data_type
        and op.C.element == data_type
        and op.accumulator_type() == acc_type
        and op.A.layout == a_layout
        and op.B.layout == b_layout
        and op.C.layout == c_layout
    ):
        ret += [copy.deepcopy(op)]
    return ret
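# Example (added for clarity; not part of the original module): an op whose A,
# B and C tensors are all f16, whose accumulator type is f32, and whose layouts
# match (a_layout, b_layout, c_layout) is returned as a single-element list
# [copy.deepcopy(op)]; any mismatch yields an empty list, which extract_config()
# treats as "skip this candidate".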
def make_fproc_f16(func_attrs, layout, op_kind, extra_kind):
    """Sets a callback for processing the epilogue of the kernel
    associated with func_attrs.

    Parameters
    ----------
    func_attrs : Dict
        Kernel attributes dictionary.
    layout : layout object
        Kernel layout.
    op_kind : ck_lib.library.OperationKind
        Operation kind.
    extra_kind : ck_lib.library.[AnyKind]
        Used as an extra flag to distinguish kernels,
        e.g. bias_add_relu vs. add_relu_bias.
    """

    def fproc_f16(op):
        a_layout, b_layout, c_layout = layout.ck_lib_layouts()
        return default_fproc_f16(
            op=op,
            a_layout=a_layout,
            b_layout=b_layout,
            c_layout=c_layout,
        )

    func_attrs["op_instance"] = extract_config(op_kind, extra_kind, fproc_f16)
python/AIT implementation/sample files/gemm_rrr.cpp
0 → 100644
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "logging.h"
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT =
    ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
        ck::tensor_layout::gemm::RowMajor,
        ck::tensor_layout::gemm::RowMajor,
        ck::tensor_layout::gemm::RowMajor,
        ck::half_t,
        ck::half_t,
        ck::half_t,
        float,
        ck::half_t,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::device::GemmSpecialization::MNKPadding,
        1,
        256,                        // block_size
        64,                         // m_per_block
        128,                        // n_per_block
        32,                         // k_per_block
        8,                          // ak1
        2,                          // bk1
        32,                         // m_per_xdl
        32,                         // n_per_xdl
        1,                          // m_xdl_per_wave
        2,                          // n_xdl_per_wave
        ck::Sequence<4, 64, 1>,     // thread_cluster_length
        ck::Sequence<1, 0, 2>,      // thread_cluster_arrange_order
        ck::Sequence<1, 0, 2>,      // src_access_order
        2,                          // src_vector_dim
        8,                          // src_scalar_per_vector
        8,                          // dst_scalar_per_vector
        1,                          // add_extra_dim
        ck::Sequence<8, 32, 1>,     // thread_cluster_length
        ck::Sequence<0, 2, 1>,      // thread_cluster_arrange_order
        ck::Sequence<0, 2, 1>,      // src_access_order
        1,                          // src_vector_dim
        4,                          // src_scalar_per_vector
        2,                          // dst_scalar_per_vector
        0,                          // add_extra_dim
        1,                          // m_xdl_per_wave
        1,                          // n_xdl_per_wave
        ck::Sequence<1, 32, 1, 8>,  // m_n_block_wave_per_xdl
        8                           // scalar_per_vector
    >;

using fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0 = gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT;

void gemm_rrr_3(
    void* in_ptr,
    void* weight_ptr,
    void* out_ptr,
    int64_t* a_dim0,
    int64_t* a_dim1,
    int64_t* b_dim0,
    int64_t* b_dim1,
    int64_t* c_dim0,
    int64_t* c_dim1,
    hipStream_t stream
) {
  ck::index_t M = (*a_dim0);
  ck::index_t N = (*b_dim1);
  ck::index_t K = (*a_dim1);

  int64_t offset_a = 0;
  int64_t offset_b = 0;
  int64_t offset_c = 0;

  ck::index_t stride_a = *a_dim1;
  ck::index_t stride_b = *b_dim1;
  ck::index_t stride_c = *c_dim1;

  if (M == 256 && N == 32 && K == 128) {
    auto op = fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0{};
    auto invoker = op.MakeInvoker();
    auto argument = op.MakeArgument(
        static_cast<ck::half_t*>(in_ptr) + offset_a,
        static_cast<ck::half_t*>(weight_ptr) + offset_b,
        static_cast<ck::half_t*>(out_ptr) + offset_c,
        M,
        N,
        K,
        stride_a,
        stride_b,
        stride_c,
        ck::tensor_operation::element_wise::PassThrough{},
        ck::tensor_operation::element_wise::PassThrough{},
        ck::tensor_operation::element_wise::PassThrough{});
    if (!op.IsSupportedArgument(argument)) {
      LOG(FATAL) << "wrong! " << op.GetTypeString()
                 << " with the specified compilation parameters does not support this Gemm problem.";
    }
    invoker.Run(argument, StreamConfig{stream, false});
    return;
  }
  LOG(FATAL) << "Unsupported workload for this gemm specialization.";
}
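// Note (added for clarity, not in the original generated file): with all three
// tensors row-major, A is M x K, B is K x N and C is M x N, so the leading
// dimensions passed above are stride_a = K (= *a_dim1), stride_b = N (= *b_dim1)
// and stride_c = N (= *c_dim1), and the kernel is only dispatched for the
// statically profiled shape M = 256, N = 32, K = 128.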
python/AIT implementation/sample files/gemm_rrr_3.cpp
0 → 100644
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
// #include <half.hpp>
#include <random>
#include <rocrand/rocrand.h>
#include "logging.h"
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
using gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT =
    ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
        ck::tensor_layout::gemm::RowMajor,
        ck::tensor_layout::gemm::RowMajor,
        ck::tensor_layout::gemm::RowMajor,
        ck::half_t,
        ck::half_t,
        ck::half_t,
        float,
        ck::half_t,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::device::GemmSpecialization::MNKPadding,
        1,
        256,                        // block_size
        64,                         // m_per_block
        128,                        // n_per_block
        32,                         // k_per_block
        8,                          // ak1
        2,                          // bk1
        32,                         // m_per_xdl
        32,                         // n_per_xdl
        1,                          // m_xdl_per_wave
        2,                          // n_xdl_per_wave
        ck::Sequence<4, 64, 1>,     // thread_cluster_length
        ck::Sequence<1, 0, 2>,      // thread_cluster_arrange_order
        ck::Sequence<1, 0, 2>,      // src_access_order
        2,                          // src_vector_dim
        8,                          // src_scalar_per_vector
        8,                          // dst_scalar_per_vector
        1,                          // add_extra_dim
        ck::Sequence<8, 32, 1>,     // thread_cluster_length
        ck::Sequence<0, 2, 1>,      // thread_cluster_arrange_order
        ck::Sequence<0, 2, 1>,      // src_access_order
        1,                          // src_vector_dim
        4,                          // src_scalar_per_vector
        2,                          // dst_scalar_per_vector
        0,                          // add_extra_dim
        1,                          // m_xdl_per_wave
        1,                          // n_xdl_per_wave
        ck::Sequence<1, 32, 1, 8>,  // m_n_block_wave_per_xdl
        8                           // scalar_per_vector
    >;

using fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0 = gemm_2_hhh_TTT_256_64_128_32_8_2_32_32_1_2_PT;

void gemm_rrr_3(
    void* in_ptr,
    void* weight_ptr,
    void* out_ptr,
    int64_t* a_dim0,
    int64_t* a_dim1,
    int64_t* b_dim0,
    int64_t* b_dim1,
    int64_t* c_dim0,
    int64_t* c_dim1,
    hipStream_t stream
) {
  ck::index_t M = (*a_dim0);
  ck::index_t N = (*b_dim1);
  ck::index_t K = (*a_dim1);

  int64_t offset_a = 0;
  int64_t offset_b = 0;
  int64_t offset_c = 0;

  ck::index_t stride_a = *a_dim1;
  ck::index_t stride_b = *b_dim1;
  ck::index_t stride_c = *c_dim1;

  if (M == 256 && N == 32 && K == 128) {
    auto op = fe7d3cbb34ba481ca532c1fecec7ec7cc5fbb35d0{};
    auto invoker = op.MakeInvoker();
    auto argument = op.MakeArgument(
        static_cast<ck::half_t*>(in_ptr) + offset_a,
        static_cast<ck::half_t*>(weight_ptr) + offset_b,
        static_cast<ck::half_t*>(out_ptr) + offset_c,
        M,
        N,
        K,
        stride_a,
        stride_b,
        stride_c,
        ck::tensor_operation::element_wise::PassThrough{},
        ck::tensor_operation::element_wise::PassThrough{},
        ck::tensor_operation::element_wise::PassThrough{});
    if (!op.IsSupportedArgument(argument)) {
      LOG(FATAL) << "wrong! " << op.GetTypeString()
                 << " with the specified compilation parameters does not support this Gemm problem.";
    }
    invoker.Run(argument, StreamConfig{stream, false});
    return;
  }
  LOG(FATAL) << "Unsupported workload for this gemm specialization.";
}
python/AIT implementation/sample files/layernorm.cpp
0 → 100644
size_t GLOBAL_WORKSPACE_SIZE = 0;
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <random>
#include <rocrand/rocrand.h>
#include "include/ck/utility/print.hpp"
#include "library/include/ck/library/utility/device_memory.hpp"
#include "library/include/ck/library/utility/host_tensor.hpp"
#include "library/include/ck/library/utility/host_tensor_generator.hpp"
#include "include/ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/reduction_operator.hpp"
#include "include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
using layernorm_rank2_256_1_256_1_8_1_8_1_8_1_8_8 =
    ck::tensor_operation::device::DeviceLayernormImpl<
        ck::half_t,
        ck::half_t,
        ck::half_t,
        float,
        ck::half_t,
        ck::tensor_operation::element_wise::PassThrough,
        2,
        1,
        256,  // block_size
        1,    // m_cluster_size
        256,  // k_cluster_size
        1,    // m_slice_size
        8,    // k_slice_size
        1,    // in_src_dim
        8,    // in_src_size
        1,    // gamma_src_dim
        8,    // gamma_src_size
        1,    // beta_src_dim
        8,    // beta_src_size
        8     // out_dst_size
    >;

using DeviceInstance = layernorm_rank2_256_1_256_1_8_1_8_1_8_1_8_8;

void layernorm_0(
    void* input,
    void* gamma,
    void* beta,
    void* output,
    int64_t* in_0,
    int64_t* in_1,
    hipStream_t stream
) {
  int M = *in_0;
  int N = *in_1;

  std::vector<ck::index_t> i_inStrides;
  i_inStrides.push_back(N);
  i_inStrides.push_back(1);

  auto device_instance = DeviceInstance{};
  auto argument_ptr = device_instance.MakeArgumentPointer(
      {M, N},
      i_inStrides,
      std::vector<ck::index_t>{0, 1},
      std::vector<ck::index_t>{0, 1},
      i_inStrides,
      {1},
      1e-05,
      static_cast<ck::half_t*>(input),
      static_cast<ck::half_t*>(gamma),
      static_cast<ck::half_t*>(beta),
      static_cast<ck::half_t*>(output),
      ck::tensor_operation::element_wise::PassThrough{});
  if (!device_instance.IsSupportedArgument(argument_ptr.get())) {
    throw std::runtime_error(
        "wrong! device_layernorm with the specified compilation parameters does "
        "not support this Softmax problem");
  };
  std::string instance_name = device_instance.GetTypeString();
  auto invoker_ptr = device_instance.MakeInvokerPointer();
  invoker_ptr->Run(argument_ptr.get(), StreamConfig{stream, false});
  return;
}
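// Note (added for clarity, not in the original generated file): for a row-major
// [M, N] input the element strides are {N, 1}; gamma and beta are passed with
// strides {0, 1}, i.e. a zero stride along M, so the same length-N gamma/beta
// vectors are reused for every row, and the normalization reduces over
// dimension {1} (the last axis) with eps = 1e-05.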
struct ProfilerMemoryPool {
  ProfilerMemoryPool() {
    std::random_device rd;
    gen = std::mt19937(rd());
    uniform_dist = std::uniform_int_distribution<int64_t>(1, 48964896);
    offsets.reserve(512);
    strides.reserve(512);
    copies.reserve(512);
    ptrs.reserve(512);
  }

  ~ProfilerMemoryPool() {
    for (int i = 0; i < ptrs.size(); i++) {
      hipFree(ptrs[i]);
    }
  }

  template <typename DType>
  DType* AllocateGaussianTensor(int64_t size) {
    size_t length = size * sizeof(DType);
    DType* d_x;
    hipMalloc(&d_x, length);
    float mean = 0.0f;
    float stddev = 1.0f;
    uint64_t seed = uniform_dist(gen);
    rocrand_set_seed(generator, seed);
    rocrand_generate_normal(generator, reinterpret_cast<float*>(d_x), size, mean, stddev);
    return d_x;
  }

  ck::half_t* AllocateHalfGaussianTensor(int64_t size) {
    return reinterpret_cast<ck::half_t*>(AllocateGaussianTensor<ck::half_t>(size));
  }

  int AllocateHalfTensor(int64_t size, int64_t copy) {
    offsets.push_back(0);
    strides.push_back(size);
    copies.push_back(copy);
    auto ptr = AllocateHalfGaussianTensor(size * copy);
    ptrs.push_back(reinterpret_cast<void*>(ptr));
    return ptrs.size() - 1;
  }

  ck::half_t* RequestHalfTensorByIdx(int idx) {
    auto copy = copies.at(idx);
    auto offset = offsets.at(idx);
    auto stride = strides.at(idx);
    ck::half_t* ptr = reinterpret_cast<ck::half_t*>(ptrs.at(idx));
    ptr += offset;
    offset += stride;
    if (offset == copy * stride) {
      offset = 0;
    }
    offsets[idx] = offset;
    return ptr;
  }

  std::vector<int64_t> offsets;
  std::vector<int64_t> strides;
  std::vector<int64_t> copies;
  std::vector<void*> ptrs;
  std::mt19937 gen;
  std::uniform_int_distribution<int64_t> uniform_dist;
  rocrand_generator generator;
};

// hack for DeviceMem linking error
// TODO fix this by making CK a header-only lib
// <<< hack begin
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
  hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}

void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }

void DeviceMem::ToDevice(const void* p) const
{
  hipGetErrorString(
      hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}

void DeviceMem::FromDevice(void* p) const
{
  hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}

DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }

struct KernelTimerImpl
{
  KernelTimerImpl() {
    hipGetErrorString(hipEventCreate(&mStart));
    hipGetErrorString(hipEventCreate(&mEnd));
  }
  ~KernelTimerImpl() {
    hipGetErrorString(hipEventDestroy(mStart));
    hipGetErrorString(hipEventDestroy(mEnd));
  }
  void Start() {
    hipGetErrorString(hipDeviceSynchronize());
    hipGetErrorString(hipEventRecord(mStart, nullptr));
  }
  void End() {
    hipGetErrorString(hipEventRecord(mEnd, nullptr));
    hipGetErrorString(hipEventSynchronize(mEnd));
  }
  float GetElapsedTime() const {
    float time;
    hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
    return time;
  }
  hipEvent_t mStart, mEnd;
};
// >>> hack end

int main(int argc, char** argv) {
  const int64_t in_0 = std::stoi(argv[1]);
  const int64_t in_1 = std::stoi(argv[2]);

  auto memory_pool = std::make_unique<ProfilerMemoryPool>();
  hipStream_t stream = nullptr;

  int64_t ptr_sz = in_0 * in_1;
  int64_t norm_dim = in_1;
  // TODO: special pool size for 8M L2 cache
  // need to tune it for other devices
  int64_t mem_pool_sz = std::max(2, std::min(64, int((1 << 23) / ptr_sz)));
  memory_pool->AllocateHalfTensor(ptr_sz, mem_pool_sz);    // in: index 0
  memory_pool->AllocateHalfTensor(ptr_sz, mem_pool_sz);    // out: index 1
  memory_pool->AllocateHalfTensor(norm_dim, mem_pool_sz);  // gamma: index 2
  memory_pool->AllocateHalfTensor(norm_dim, mem_pool_sz);  // beta: index 3

  // warmup
  for (int i = 0; i < 3; ++i) {
    layernorm_0(
        (void*)memory_pool->RequestHalfTensorByIdx(0),
        (void*)memory_pool->RequestHalfTensorByIdx(2),
        (void*)memory_pool->RequestHalfTensorByIdx(3),
        (void*)memory_pool->RequestHalfTensorByIdx(1),
        const_cast<int64_t*>(&in_0),
        const_cast<int64_t*>(&in_1),
        stream);
  }
  // run
  KernelTimerImpl timer;
  timer.Start();
  for (int i = 0; i < 5; ++i) {
    layernorm_0(
        (void*)memory_pool->RequestHalfTensorByIdx(0),
        (void*)memory_pool->RequestHalfTensorByIdx(2),
        (void*)memory_pool->RequestHalfTensorByIdx(3),
        (void*)memory_pool->RequestHalfTensorByIdx(1),
        const_cast<int64_t*>(&in_0),
        const_cast<int64_t*>(&in_1),
        stream);
  }
  timer.End();
  std::cout << "WS:" << GLOBAL_WORKSPACE_SIZE << std::endl;
  std::cout << "TIME:" << timer.GetElapsedTime() << std::endl;
}
python/AIT implementation/sample files/model-generate.h
0 → 100644
#pragma once
#include "logging.h"
#include "device_functions-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include "macros.h"
#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <math.h>
void gemm_rrr_3(
    void*,
    void*,
    void*,
    int64_t*,
    int64_t*,
    int64_t*,
    int64_t*,
    int64_t*,
    int64_t*,
    hipStream_t
);

#define CHECK_VECTOR_ACCESS(vector, idx)                                  \
  if (idx >= vector.size()) {                                             \
    throw std::out_of_range(                                              \
        "[__func__]: index out of range, " #vector ".size()=" +           \
        std::to_string(vector.size()) + ", got " + std::to_string(idx));  \
  }

namespace ait {
namespace {

void DeviceCheckLastError(const char* file, int line) {
  auto device_error = GetLastError();
  if (device_error != GetDeviceSuccess()) {
    std::string msg = std::string("Got error: ") + GetLastErrorString() +
        " enum: " + std::to_string(device_error) + " at " + file + ": " +
        std::to_string(line);
    LOG(ERROR) << msg;
    throw std::runtime_error(msg);
  }
}

thread_local bool target_has_graph_mode = false;

} // namespace

// Model is the class that actually performs inference. It owns memory for
// intermediate tensors and dynamic dimensions. Constants are owned by
// the model's owning container object, and input/output memory is owned
// by the user.
// Once an inference run has started, it is not safe to re-use the Model
// until the run has finished!
class Model {
 public:
  Model(
      size_t blob_size,
      size_t workspace_size,
      size_t num_inputs,
      size_t num_outputs,
      size_t num_unbound_constants,
      uint8_t* constants,
      AITemplateAllocator& allocator)
      : blob_(RAII_DeviceMalloc(blob_size, allocator)),
        workspace_(RAII_DeviceMalloc(workspace_size, allocator)),
        params_(num_inputs + num_outputs + num_unbound_constants),
        num_inputs_(num_inputs),
        num_outputs_(num_outputs),
        constants_(constants) {
    dmlc::InitLogging("aitemplate");
    // TODO(xxx): render network name
    LOG(INFO) << "Init AITemplate Runtime.";
    global_workspace_ = static_cast<uint8_t*>(workspace_.get()) + 0;
    unique_workspace_ = static_cast<uint8_t*>(workspace_.get());
    DEVICE_CHECK(GetDevice(&device_idx_))
    DEVICE_CHECK(CreateEvent(&run_finished_));
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
    DEVICE_CHECK(cudaDeviceGetAttribute(
        &max_smem_size_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx_));
#endif
    DEVICE_CHECK(GetDeviceProperties(&device_properties_, device_idx_));
    DEVICE_CHECK(StreamCreate(&graph_capture_stream_, /*non_blocking=*/true));
    InitConstants(constants_);
    auto* blob_ptr = static_cast<uint8_t*>(blob_.get());
  }

  ~Model() {
    if (run_finished_ != nullptr) {
      DestroyEvent(run_finished_);
    }
    if (graph_capture_stream_ != nullptr) {
      StreamDestroy(graph_capture_stream_);
    }
    if (graph_exec_ != nullptr) {
      GraphExecDestroy(graph_exec_);
    }
  }

  Model(Model&& other) {
    run_finished_ = other.run_finished_;
    graph_exec_ = other.graph_exec_;
    graph_capture_stream_ = other.graph_capture_stream_;
    other.run_finished_ = nullptr;
    other.graph_exec_ = nullptr;
    other.graph_capture_stream_ = nullptr;
    constants_ = other.constants_;
    num_inputs_ = other.num_inputs_;
    global_workspace_ = other.global_workspace_;
    unique_workspace_ = other.unique_workspace_;
    workspace_ = std::move(other.workspace_);
    params_ = std::move(other.params_);
    constant_name_to_ptr_ = std::move(other.constant_name_to_ptr_);
    // Re-wire the pointers in the above 2 structures.
    InitConstants(constants_);
  }

  Model& operator=(Model&&) = delete;
  Model(const Model&) = delete;
  Model& operator=(const Model&) = delete;

  void SetUpInputsOutputs() {
    input_0 = static_cast<decltype(input_0)>(params_[0].ptr);
    if (input_0 == nullptr) {
      throw std::runtime_error(
          "Constant input_0 was not set! Set the value with set_constant.");
    }
    input_1 = static_cast<decltype(input_1)>(params_[1].ptr);
    if (input_1 == nullptr) {
      throw std::runtime_error(
          "Constant input_1 was not set! Set the value with set_constant.");
    }
    output_0 = static_cast<decltype(output_0)>(params_[2].ptr);
    if (output_0 == nullptr) {
      throw std::runtime_error(
          "Constant output_0 was not set! Set the value with set_constant.");
    }
  }

  void DeviceToDeviceCopies(StreamType stream) {}

  void Run(StreamType stream, bool graph_mode) {
    SetUpInputsOutputs();
    if (target_has_graph_mode && graph_mode) {
      RunAsGraph(stream);
    } else {
      RunImpl(stream);
    }
    DEVICE_CHECK(EventRecord(run_finished_, stream));
  }

  void RunImpl(StreamType stream) {
    gemm_rrr_3(
        input_0,
        input_1,
        output_0,
        &input_0_dim_0,
        &input_0_dim_1,
        &input_1_dim_0,
        &input_1_dim_1,
        &input_0_dim_0,
        &input_1_dim_1,
        stream);
    DeviceCheckLastError(__FILE__, __LINE__);
    DeviceToDeviceCopies(stream);
  }
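  // Note (added for clarity, not in the original generated header): RunImpl
  // wires up the single op of this model. gemm_rrr_3 consumes input_0 as A
  // ([input_0_dim_0, input_0_dim_1] = [256, 128]), input_1 as B
  // ([input_1_dim_0, input_1_dim_1] = [128, 32]) and writes output_0 as C with
  // shape [input_0_dim_0, input_1_dim_1] = [256, 32], matching the ParamDim
  // bounds set up in InitConstants() below.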
  bool IsPending() {
    auto query = QueryEvent(run_finished_);
    if (query == GetDeviceNotReady()) {
      return true;
    }
    if (query != GetDeviceSuccess()) {
      LOG(WARNING) << "Pending model run did not finish successfully. Error: "
                   << GetErrorString(query);
    }
    return false;
  }

  void WaitForCompletion() {
    DEVICE_CHECK(EventSynchronize(run_finished_));
  }

  size_t NumInputs() const {
    return num_inputs_;
  }

  size_t NumOutputs() const {
    return num_outputs_;
  }

  void SetParam(const void* src, size_t param_idx) {
    CHECK_VECTOR_ACCESS(params_, param_idx)
    // const_cast is not ideal here, but it is unfortunately
    // necessary:
    // 1) We store outputs and inputs in the same vector,
    //    and outputs cannot be const.
    // 2) Most of the codegen is not const-correct (most ops
    //    require non-const pointers). So even if we put const
    //    pointers into params, a const_cast would be required
    //    somewhere else.
    params_[param_idx].ptr = const_cast<void*>(src);
  }

  void SetInput(const void* src, const AITemplateParamShape& shape, size_t idx) {
    SetInputShape(shape, idx);
    SetParam(src, idx);
  }

  void SetOutput(void* src, size_t idx) {
    SetParam(src, idx + num_inputs_);
  }

  // Write the (possibly dynamic) output shape to the given pointer.
  // Note that this should be called _after_ the shape inference in
  // Run() is finished. output_shape_out should be able to store
  // at least GetOutputMaximumShape(idx).size values.
  void GetOutputShape(size_t idx, int64_t* output_shape_out) {
    const auto param_idx = idx + num_inputs_;
    CHECK_VECTOR_ACCESS(params_, param_idx);
    const auto& shape_ptrs = params_[param_idx].shape_ptrs;
    for (size_t i = 0; i < shape_ptrs.size(); ++i) {
      output_shape_out[i] = shape_ptrs[i].GetValue();
    }
  }

  void SetConstant(const char* name, const void* src) {
    auto it = constant_name_to_ptr_.find(name);
    if (it == constant_name_to_ptr_.end()) {
      throw std::out_of_range(std::string("Could not find constant ") + name);
    }
    const void** ptr = it->second;
    *ptr = src;
  }

 private:
  void InitConstants(uint8_t* constants) {
    params_[0].shape_ptrs = {
        ParamDim(256, 256, &input_0_dim_0), ParamDim(128, 128, &input_0_dim_1)};
    params_[1].shape_ptrs = {
        ParamDim(128, 128, &input_1_dim_0), ParamDim(32, 32, &input_1_dim_1)};
    params_[2].shape_ptrs = {
        ParamDim(256, 256, &input_0_dim_0), ParamDim(32, 32, &input_1_dim_1)};
  }

  void SetInputShape(const AITemplateParamShape& shape, size_t idx) {
    auto& param = params_[idx];
    if (shape.size != param.shape_ptrs.size()) {
      throw std::runtime_error(
          "[SetInputShape] Got wrong param shape for input " + std::to_string(idx) +
          "; expected " + std::to_string(param.shape_ptrs.size()) + ", got " +
          std::to_string(shape.size));
    }
    for (size_t i = 0; i < param.shape_ptrs.size(); ++i) {
      param.shape_ptrs[i].SetValue(shape.shape_data[i]);
    }
  }

  DeviceError EndCapture(GraphType* graph_ptr) {
    auto err = StreamEndCapture(graph_capture_stream_, graph_ptr);
    if (err != GetDeviceSuccess()) {
      // If we can't take the stream out of capture mode, something is probably
      // wrong with CUDA graph for this model (e.g. there might have been an
      // illegal capture mode operation). Disable graph mode to avoid such issues
      // in future iterations.
      target_has_graph_mode = false;
      LOG(WARNING) << "Graph capture failed to end. Disabling graph mode.";
      return err;
    }
    return GetDeviceSuccess();
  }

  void RunAsGraph(StreamType stream) {
    DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
    try {
      RunImpl(graph_capture_stream_);
    } catch (...) {
      GraphType graph;
      // No need to DEVICE_CHECK here, we want to see the original exception.
      EndCapture(&graph);
      if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
        LOG(WARNING)
            << "Graph destruction failed while handling exception! Memory will be leaked.";
      }
      throw;
    }
    // The following function ends the capture and creates a graph
    // inside a unique_ptr that cleans it up when it goes out of scope.
    // Note that it throws an exception if EndCapture fails.
    auto graph = RAII_EndCaptureAndCreateGraph(
        [this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });
    if (graph_exec_ == nullptr) {
      DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
    } else if (GraphExecUpdate(graph_exec_, graph.get()) != GetDeviceSuccess()) {
      // Consume the last cuda error, which may affect the next GraphExecLaunch
      // call.
      GetLastError();
      DEVICE_CHECK(GraphExecDestroy(graph_exec_));
      DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
    }
    DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
  }

  int device_idx_;
  int max_smem_size_{0};
  DevicePropertyType device_properties_;
  // This event tracks when the inference is finished
  // so that this Model may be reclaimed by its owning
  // ModelContainer.
  EventType run_finished_;
  // A blob of memory used for storing intermediate tensors.
  GPUPtr blob_;
  // Memory for constants that were folded into the *.so. Unowned by Model,
  // owned by ModelContainer.
  // TODO: make this const. It can't be const right now because we derive
  // tensor pointers from it, and no tensor pointers are const.
  uint8_t* constants_;
  size_t num_inputs_;
  size_t num_outputs_;
  // The workspace blob is used as scratch memory. See
  // _generate_workspace in memory planning for more information.
  GPUPtr workspace_;
  uint8_t* global_workspace_{nullptr};
  uint8_t* unique_workspace_{nullptr};

  class ParamDim {
   public:
    ParamDim(int64_t lower_bound, int64_t upper_bound, int64_t* value)
        : lower_bound_(lower_bound), upper_bound_(upper_bound), value_(value) {}

    void SetValue(int64_t new_value) {
      if (new_value < lower_bound_ || new_value > upper_bound_) {
        throw std::out_of_range(
            "[SetValue] Dimension got value out of bounds; expected value to be in [" +
            std::to_string(lower_bound_) + ", " + std::to_string(upper_bound_) +
            "], but got " + std::to_string(new_value));
      }
      *value_ = new_value;
    }

    int64_t GetValue() const {
      return *value_;
    }

   private:
    int64_t lower_bound_;
    int64_t upper_bound_;
    int64_t* value_;
  };

  struct ParamInfo {
    void* ptr = nullptr;
    // TODO add offset
    const char* name;
    std::vector<ParamDim> shape_ptrs;
  };

  // Contains info for all tensors marked as inputs
  // or outputs. The first num_inputs elements are the inputs.
  // Constants are not included.
  std::vector<ParamInfo> params_;

  GraphExecType graph_exec_ = nullptr;
  StreamType graph_capture_stream_;
  std::unordered_map<std::string, const void**> constant_name_to_ptr_;

  void* input_0{nullptr};
  void* input_1{nullptr};
  void* output_0{nullptr};
  int64_t input_0_dim_0{256};
  int64_t input_0_dim_1{128};
  int64_t input_1_dim_0{128};
  int64_t input_1_dim_1{32};
};
} // namespace ait
\ No newline at end of file
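
The RunAsGraph() path above captures the whole RunImpl() call sequence on a dedicated side stream and then replays it as a single device graph, updating the cached executable graph in place when possible. The standalone sketch below is not part of this commit; it shows the same capture/instantiate/update pattern written directly against the HIP graph API. It assumes a ROCm build that ships the hipGraph* entry points (they mirror the CUDA graph API), and HIP_CHECK and LaunchAsGraph are local helper names, not names from the AIT runtime. A Model would pass a lambda that calls RunImpl(capture_stream) as enqueue_work.

// Standalone sketch of the capture -> instantiate/update -> launch pattern.
// Assumes HIP graph support; HIP_CHECK and LaunchAsGraph are local helpers,
// not names from the AIT runtime.
#include <hip/hip_runtime.h>
#include <stdexcept>

#define HIP_CHECK(expr)                                    \
  do {                                                     \
    hipError_t _err = (expr);                              \
    if (_err != hipSuccess)                                \
      throw std::runtime_error(hipGetErrorString(_err));   \
  } while (0)

template <typename EnqueueWork>
void LaunchAsGraph(EnqueueWork&& enqueue_work,
                   hipStream_t capture_stream,
                   hipStream_t launch_stream,
                   hipGraphExec_t& graph_exec /* cached across calls */) {
  // 1) Record the kernel launches into a graph instead of running them eagerly.
  HIP_CHECK(hipStreamBeginCapture(capture_stream, hipStreamCaptureModeThreadLocal));
  enqueue_work(capture_stream);
  hipGraph_t graph = nullptr;
  HIP_CHECK(hipStreamEndCapture(capture_stream, &graph));

  // 2) First call: instantiate. Later calls: try a cheap in-place update and
  //    only rebuild the executable graph if the topology changed.
  if (graph_exec == nullptr) {
    HIP_CHECK(hipGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));
  } else {
    hipGraphNode_t error_node = nullptr;
    hipGraphExecUpdateResult update_result;
    if (hipGraphExecUpdate(graph_exec, graph, &error_node, &update_result) !=
        hipSuccess) {
      (void)hipGetLastError(); // clear the sticky error before rebuilding
      HIP_CHECK(hipGraphExecDestroy(graph_exec));
      graph_exec = nullptr;
      HIP_CHECK(hipGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));
    }
  }
  HIP_CHECK(hipGraphDestroy(graph));

  // 3) Replay the whole captured sequence with a single launch.
  HIP_CHECK(hipGraphLaunch(graph_exec, launch_stream));
}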
python/AIT implementation/sample files/model-generated.h
0 → 100644
View file @
516bbdcb
#pragma once
#include "logging.h"
#include "device_functions-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include "macros.h"
#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <math.h>

void gemm_rrr_3(
    void*,
    void*,
    void*,
    int64_t*,
    int64_t*,
    int64_t*,
    int64_t*,
    int64_t*,
    int64_t*,
    hipStream_t);

#define CHECK_VECTOR_ACCESS(vector, idx)                                 \
  if (idx >= vector.size()) {                                            \
    throw std::out_of_range(                                             \
        "[__func__]: index out of range, " #vector ".size()=" +          \
        std::to_string(vector.size()) + ", got " + std::to_string(idx)); \
  }

namespace ait {
namespace {
void DeviceCheckLastError(const char* file, int line) {
  auto device_error = GetLastError();
  if (device_error != GetDeviceSuccess()) {
    std::string msg = std::string("Got error: ") + GetLastErrorString() +
        " enum: " + std::to_string(device_error) + " at " + file + ": " +
        std::to_string(line);
    LOG(ERROR) << msg;
    throw std::runtime_error(msg);
  }
}

thread_local bool target_has_graph_mode = false;
} // namespace

// Model is the class that actually performs inference. It owns memory for
// intermediate tensors and dynamic dimensions. Constants are owned by
// the model's owning container object, and input/output memory is owned
// by the user.
// Once an inference run has started, it is not safe to re-use the Model
// until the run has finished!
class Model {
 public:
  Model(
      size_t blob_size,
      size_t workspace_size,
      size_t num_inputs,
      size_t num_outputs,
      size_t num_unbound_constants,
      uint8_t* constants,
      AITemplateAllocator& allocator)
      : blob_(RAII_DeviceMalloc(blob_size, allocator)),
        workspace_(RAII_DeviceMalloc(workspace_size, allocator)),
        params_(num_inputs + num_outputs + num_unbound_constants),
        num_inputs_(num_inputs),
        num_outputs_(num_outputs),
        constants_(constants) {
    dmlc::InitLogging("aitemplate");
    // TODO(xxx): render network name
    LOG(INFO) << "Init AITemplate Runtime.";
    global_workspace_ = static_cast<uint8_t*>(workspace_.get()) + 0;
    unique_workspace_ = static_cast<uint8_t*>(workspace_.get());
    DEVICE_CHECK(GetDevice(&device_idx_))
    DEVICE_CHECK(CreateEvent(&run_finished_));
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
    DEVICE_CHECK(cudaDeviceGetAttribute(
        &max_smem_size_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx_));
#endif
    DEVICE_CHECK(GetDeviceProperties(&device_properties_, device_idx_));
    DEVICE_CHECK(StreamCreate(&graph_capture_stream_, /*non_blocking=*/true));
    InitConstants(constants_);
    auto* blob_ptr = static_cast<uint8_t*>(blob_.get());
  }

  ~Model() {
    if (run_finished_ != nullptr) {
      DestroyEvent(run_finished_);
    }
    if (graph_capture_stream_ != nullptr) {
      StreamDestroy(graph_capture_stream_);
    }
    if (graph_exec_ != nullptr) {
      GraphExecDestroy(graph_exec_);
    }
  }

  Model(Model&& other) {
    run_finished_ = other.run_finished_;
    graph_exec_ = other.graph_exec_;
    graph_capture_stream_ = other.graph_capture_stream_;
    other.run_finished_ = nullptr;
    other.graph_exec_ = nullptr;
    other.graph_capture_stream_ = nullptr;
    constants_ = other.constants_;
    num_inputs_ = other.num_inputs_;
    global_workspace_ = other.global_workspace_;
    unique_workspace_ = other.unique_workspace_;
    workspace_ = std::move(other.workspace_);
    params_ = std::move(other.params_);
    constant_name_to_ptr_ = std::move(other.constant_name_to_ptr_);
    // Re-wire the pointers in the above 2 structures.
    InitConstants(constants_);
  }
  Model& operator=(Model&&) = delete;
  Model(const Model&) = delete;
  Model& operator=(const Model&) = delete;

  void SetUpInputsOutputs() {
    input_0 = static_cast<decltype(input_0)>(params_[0].ptr);
    if (input_0 == nullptr) {
      throw std::runtime_error(
          "Constant input_0 was not set! Set the value with set_constant.");
    }
    input_1 = static_cast<decltype(input_1)>(params_[1].ptr);
    if (input_1 == nullptr) {
      throw std::runtime_error(
          "Constant input_1 was not set! Set the value with set_constant.");
    }
    output_0 = static_cast<decltype(output_0)>(params_[2].ptr);
    if (output_0 == nullptr) {
      throw std::runtime_error(
          "Constant output_0 was not set! Set the value with set_constant.");
    }
  }

  void DeviceToDeviceCopies(StreamType stream) {}

  void Run(StreamType stream, bool graph_mode) {
    SetUpInputsOutputs();
    if (target_has_graph_mode && graph_mode) {
      RunAsGraph(stream);
    } else {
      RunImpl(stream);
    }
    DEVICE_CHECK(EventRecord(run_finished_, stream));
  }

  void RunImpl(StreamType stream) {
    gemm_rrr_3(
        input_0,
        input_1,
        output_0,
        &input_0_dim_0,
        &input_0_dim_1,
        &input_1_dim_0,
        &input_1_dim_1,
        &input_0_dim_0,
        &input_1_dim_1,
        stream);
    DeviceCheckLastError(__FILE__, __LINE__);
    DeviceToDeviceCopies(stream);
  }

  bool IsPending() {
    auto query = QueryEvent(run_finished_);
    if (query == GetDeviceNotReady()) {
      return true;
    }
    if (query != GetDeviceSuccess()) {
      LOG(WARNING) << "Pending model run did not finish successfully. Error: "
                   << GetErrorString(query);
    }
    return false;
  }

  void WaitForCompletion() {
    DEVICE_CHECK(EventSynchronize(run_finished_));
  }

  size_t NumInputs() const {
    return num_inputs_;
  }

  size_t NumOutputs() const {
    return num_outputs_;
  }

  void SetParam(const void* src, size_t param_idx) {
    CHECK_VECTOR_ACCESS(params_, param_idx)
    // const_cast is not ideal here, but it is unfortunately necessary:
    // 1) We store outputs and inputs in the same vector,
    //    and outputs cannot be const.
    // 2) Most of the codegen is not const-correct (most ops
    //    require non-const pointers). So even if we put const
    //    pointers into params, a const_cast would be required
    //    somewhere else.
    params_[param_idx].ptr = const_cast<void*>(src);
  }

  void SetInput(const void* src, const AITemplateParamShape& shape, size_t idx) {
    SetInputShape(shape, idx);
    SetParam(src, idx);
  }

  void SetOutput(void* src, size_t idx) {
    SetParam(src, idx + num_inputs_);
  }

  // Write the (possibly dynamic) output shape to the given pointer.
  // Note that this should be called _after_ the shape inference in
  // Run() is finished. output_shape_out should be able to store
  // at least GetOutputMaximumShape(idx).size values.
  void GetOutputShape(size_t idx, int64_t* output_shape_out) {
    const auto param_idx = idx + num_inputs_;
    CHECK_VECTOR_ACCESS(params_, param_idx);
    const auto& shape_ptrs = params_[param_idx].shape_ptrs;
    for (size_t i = 0; i < shape_ptrs.size(); ++i) {
      output_shape_out[i] = shape_ptrs[i].GetValue();
    }
  }

  void SetConstant(const char* name, const void* src) {
    auto it = constant_name_to_ptr_.find(name);
    if (it == constant_name_to_ptr_.end()) {
      throw std::out_of_range(std::string("Could not find constant ") + name);
    }
    const void** ptr = it->second;
    *ptr = src;
  }

 private:
  void InitConstants(uint8_t* constants) {
    params_[0].shape_ptrs = {
        ParamDim(256, 256, &input_0_dim_0), ParamDim(128, 128, &input_0_dim_1)};
    params_[1].shape_ptrs = {
        ParamDim(128, 128, &input_1_dim_0), ParamDim(32, 32, &input_1_dim_1)};
    params_[2].shape_ptrs = {
        ParamDim(256, 256, &input_0_dim_0), ParamDim(32, 32, &input_1_dim_1)};
  }

  void SetInputShape(const AITemplateParamShape& shape, size_t idx) {
    auto& param = params_[idx];
    if (shape.size != param.shape_ptrs.size()) {
      throw std::runtime_error(
          "[SetInputShape] Got wrong param shape for input " +
          std::to_string(idx) + "; expected " +
          std::to_string(param.shape_ptrs.size()) + ", got " +
          std::to_string(shape.size));
    }
    for (size_t i = 0; i < param.shape_ptrs.size(); ++i) {
      param.shape_ptrs[i].SetValue(shape.shape_data[i]);
    }
  }

  DeviceError EndCapture(GraphType* graph_ptr) {
    auto err = StreamEndCapture(graph_capture_stream_, graph_ptr);
    if (err != GetDeviceSuccess()) {
      // If we can't take the stream out of capture mode, something is probably
      // wrong with CUDA graph for this model (e.g. there might have been an
      // illegal capture mode operation). Disable graph mode to avoid such issues
      // in future iterations.
      target_has_graph_mode = false;
      LOG(WARNING) << "Graph capture failed to end. Disabling graph mode.";
      return err;
    }
    return GetDeviceSuccess();
  }

  void RunAsGraph(StreamType stream) {
    DEVICE_CHECK(StreamBeginCapture(graph_capture_stream_, /*global=*/false));
    try {
      RunImpl(graph_capture_stream_);
    } catch (...) {
      GraphType graph;
      // No need to DEVICE_CHECK here, we want to see the original exception.
      EndCapture(&graph);
      if (graph != nullptr && GraphDestroy(graph) != GetDeviceSuccess()) {
        LOG(WARNING)
            << "Graph destruction failed while handling exception! Memory will be leaked.";
      }
      throw;
    }
    // The following function ends the capture and creates a graph
    // inside a unique_ptr that cleans it up when it goes out of scope.
    // Note that it throws an exception if EndCapture fails.
    auto graph = RAII_EndCaptureAndCreateGraph(
        [this](GraphType* graph_ptr) { return EndCapture(graph_ptr); });
    if (graph_exec_ == nullptr) {
      DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
    } else if (GraphExecUpdate(graph_exec_, graph.get()) != GetDeviceSuccess()) {
      // Consume the last cuda error, which may affect the next GraphExecLaunch
      // call.
      GetLastError();
      DEVICE_CHECK(GraphExecDestroy(graph_exec_));
      DEVICE_CHECK(GraphInstantiate(&graph_exec_, graph.get()));
    }
    DEVICE_CHECK(GraphExecLaunch(graph_exec_, stream));
  }

  int device_idx_;
  int max_smem_size_{0};
  DevicePropertyType device_properties_;
  // This event tracks when the inference is finished
  // so that this Model may be reclaimed by its owning
  // ModelContainer.
  EventType run_finished_;
  // A blob of memory used for storing intermediate tensors.
  GPUPtr blob_;
  // Memory for constants that were folded into the *.so. Unowned by Model,
  // owned by ModelContainer.
  // TODO: make this const. It can't be const right now because we derive
  // tensor pointers from it, and no tensor pointers are const.
  uint8_t* constants_;
  size_t num_inputs_;
  size_t num_outputs_;
  // The workspace blob is used as scratch memory. See
  // _generate_workspace in memory planning for more information.
  GPUPtr workspace_;
  uint8_t* global_workspace_{nullptr};
  uint8_t* unique_workspace_{nullptr};

  class ParamDim {
   public:
    ParamDim(int64_t lower_bound, int64_t upper_bound, int64_t* value)
        : lower_bound_(lower_bound), upper_bound_(upper_bound), value_(value) {}

    void SetValue(int64_t new_value) {
      if (new_value < lower_bound_ || new_value > upper_bound_) {
        throw std::out_of_range(
            "[SetValue] Dimension got value out of bounds; expected value to be in [" +
            std::to_string(lower_bound_) + ", " + std::to_string(upper_bound_) +
            "], but got " + std::to_string(new_value));
      }
      *value_ = new_value;
    }

    int64_t GetValue() const {
      return *value_;
    }

   private:
    int64_t lower_bound_;
    int64_t upper_bound_;
    int64_t* value_;
  };

  struct ParamInfo {
    void* ptr = nullptr;
    // TODO add offset
    const char* name;
    std::vector<ParamDim> shape_ptrs;
  };

  // Contains info for all tensors marked as inputs
  // or outputs. The first num_inputs elements are the inputs.
  // Constants are not included.
  std::vector<ParamInfo> params_;

  GraphExecType graph_exec_ = nullptr;
  StreamType graph_capture_stream_;
  std::unordered_map<std::string, const void**> constant_name_to_ptr_;

  void* input_0{nullptr};
  void* input_1{nullptr};
  void* output_0{nullptr};
  int64_t input_0_dim_0{256};
  int64_t input_0_dim_1{128};
  int64_t input_1_dim_0{128};
  int64_t input_1_dim_1{32};
};
} // namespace ait
\ No newline at end of file
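
The generated Model above never copies shapes around: each dynamic dimension lives as an int64_t member (input_0_dim_0 and friends), ParamDim only stores a bounds-checked pointer to it, and RunImpl() hands those same addresses to gemm_rrr_3. The toy program below is illustrative only and uses no AIT types; "Dim" is a stand-in for ParamDim, showing how a shape setter writes through the pointer that the kernel call later reads.

// Toy illustration of the ParamDim mechanism: dynamic dims are int64_t values
// owned by the model; shape setters write through pointers, and the kernel
// call later reads the same addresses. "Dim" is a stand-in, not an AIT type.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

struct Dim {
  int64_t lo, hi;
  int64_t* value; // points into the owning model, not owned here
  void Set(int64_t v) {
    if (v < lo || v > hi) throw std::out_of_range("dim out of bounds");
    *value = v;
  }
};

int main() {
  // The dims whose addresses would be passed to gemm_rrr_3 (M, K, N).
  int64_t m = 256, k = 128, n = 32;
  std::vector<Dim> a_shape{{256, 256, &m}, {128, 128, &k}};
  a_shape[0].Set(256); // what SetInputShape() does for each dimension
  std::cout << "GEMM M=" << m << " K=" << k << " N=" << n << "\n";
  return 0;
}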
python/AIT implementation/sample files/model_container.cpp
0 → 100644
View file @
516bbdcb
#include "model_container.h"
#include "device_functions-generated.h"
#include "raii_wrapper.h"
namespace
ait
{
ModelContainer
::
ModelContainer
(
size_t
num_models
,
size_t
blob_size
,
size_t
workspace_size
,
size_t
num_inputs
,
size_t
num_outputs
,
size_t
num_unbound_constants
,
size_t
params_size
,
AITemplateAllocator
&
allocator
)
:
ModelContainerBase
(
num_inputs
,
num_outputs
,
num_unbound_constants
,
params_size
,
allocator
),
allocator_
(
allocator
),
num_inputs_
(
num_inputs
),
num_outputs_
(
num_outputs
)
{
if
(
num_models
==
0
)
{
throw
std
::
runtime_error
(
"Number of models must be positive"
);
}
models_
.
reserve
(
num_models
);
available_models_
.
reserve
(
num_models
);
for
(
size_t
i
=
0
;
i
<
num_models
;
++
i
)
{
models_
.
emplace_back
(
blob_size
,
workspace_size
,
num_inputs
,
num_outputs
,
num_unbound_constants
,
static_cast
<
uint8_t
*>
(
constants_
.
get
()),
allocator
);
available_models_
.
push_back
(
&
models_
.
back
());
}
}
void
ModelContainer
::
Run
(
const
AITData
*
inputs
,
size_t
num_inputs
,
AITData
*
outputs
,
size_t
num_outputs
,
StreamType
stream
,
bool
sync
,
bool
graph_mode
,
int64_t
**
output_shapes_out
)
{
auto
*
model
=
GetAvailableModel
();
try
{
PrepareForRun
(
model
,
inputs
,
num_inputs
,
outputs
,
num_outputs
);
model
->
Run
(
stream
,
graph_mode
);
}
catch
(...)
{
std
::
lock_guard
lk
(
models_mutex_
);
available_models_
.
push_back
(
model
);
throw
;
}
if
(
output_shapes_out
)
{
for
(
size_t
i
=
0
;
i
<
num_outputs
;
++
i
)
{
auto
*
out_shape
=
output_shapes_out
[
i
];
model
->
GetOutputShape
(
i
,
out_shape
);
}
}
{
std
::
lock_guard
lk
(
models_mutex_
);
pending_models_
.
push_back
(
model
);
}
pending_models_available_
.
notify_one
();
if
(
sync
)
{
StreamSynchronize
(
stream
);
}
}
void
ModelContainer
::
RunWithOutputsOnHost
(
const
AITData
*
inputs
,
size_t
num_inputs
,
AITData
*
outputs
,
size_t
num_outputs
,
StreamType
stream
,
bool
graph_mode
,
int64_t
**
output_shapes_out
)
{
std
::
vector
<
std
::
pair
<
GPUPtr
,
size_t
>>
owned_outputs_ptrs
;
std
::
vector
<
AITData
>
owned_outputs
;
owned_outputs_ptrs
.
reserve
(
num_outputs
);
owned_outputs
.
reserve
(
num_outputs
);
for
(
size_t
i
=
0
;
i
<
num_outputs
;
++
i
)
{
size_t
num_bytes
=
MaxOutputStorageBytes
(
i
);
owned_outputs_ptrs
.
emplace_back
(
RAII_DeviceMalloc
(
num_bytes
,
allocator_
),
num_bytes
);
owned_outputs
.
emplace_back
(
owned_outputs_ptrs
.
back
().
first
.
get
(),
outputs
[
i
].
shape
,
outputs
[
i
].
dtype
);
}
Run
(
inputs
,
num_inputs
,
owned_outputs
.
data
(),
num_outputs
,
stream
,
/*sync=*/
false
,
graph_mode
,
output_shapes_out
);
for
(
size_t
i
=
0
;
i
<
num_outputs
;
++
i
)
{
auto
&
owned_output
=
owned_outputs_ptrs
[
i
];
auto
&
ptr
=
owned_output
.
first
;
auto
num_bytes
=
owned_output
.
second
;
DEVICE_CHECK
(
CopyToHost
(
outputs
[
i
].
ptr
,
ptr
.
get
(),
num_bytes
,
stream
));
}
DEVICE_CHECK
(
StreamSynchronize
(
stream
));
}
float
ModelContainer
::
Benchmark
(
const
AITData
*
inputs
,
size_t
num_inputs
,
AITData
*
outputs
,
size_t
num_outputs
,
StreamType
stream
,
bool
graph_mode
,
size_t
count
,
size_t
num_threads
,
bool
use_unique_stream_per_thread
,
int64_t
**
output_shapes_out
)
{
if
(
num_threads
==
0
)
{
num_threads
=
std
::
thread
::
hardware_concurrency
();
}
if
(
num_threads
==
1
)
{
return
BenchmarkImpl
(
inputs
,
num_inputs
,
outputs
,
num_outputs
,
stream
,
graph_mode
,
count
,
output_shapes_out
)
/
count
;
}
// Clone the outputs, each thread needs its own set
std
::
vector
<
std
::
vector
<
GPUPtr
>>
per_thread_outputs_ptrs
;
std
::
vector
<
std
::
vector
<
AITData
>>
per_thread_outputs
;
std
::
vector
<
StreamPtr
>
per_thread_streams
;
per_thread_outputs_ptrs
.
reserve
(
num_threads
-
1
);
per_thread_outputs
.
reserve
(
num_threads
-
1
);
if
(
use_unique_stream_per_thread
)
{
per_thread_streams
.
reserve
(
num_threads
);
for
(
size_t
i
=
0
;
i
<
num_threads
;
++
i
)
{
per_thread_streams
.
push_back
(
RAII_StreamCreate
(
/*non_blocking=*/
true
));
}
}
for
(
size_t
i
=
1
;
i
<
num_threads
;
++
i
)
{
std
::
vector
<
GPUPtr
>
cloned_outputs_ptrs
;
std
::
vector
<
AITData
>
cloned_outputs
;
cloned_outputs_ptrs
.
reserve
(
num_outputs
);
cloned_outputs
.
reserve
(
num_outputs
);
for
(
size_t
j
=
0
;
j
<
num_outputs
;
++
j
)
{
size_t
num_bytes
=
MaxOutputStorageBytes
(
j
);
cloned_outputs_ptrs
.
emplace_back
(
RAII_DeviceMalloc
(
num_bytes
,
allocator_
));
auto
*
new_pointer
=
cloned_outputs_ptrs
.
back
().
get
();
DEVICE_CHECK
(
DeviceToDeviceCopy
(
new_pointer
,
outputs
[
j
].
ptr
,
num_bytes
,
stream
));
cloned_outputs
.
emplace_back
(
new_pointer
,
outputs
[
j
].
shape
,
outputs
[
j
].
dtype
);
}
per_thread_outputs_ptrs
.
push_back
(
std
::
move
(
cloned_outputs_ptrs
));
per_thread_outputs
.
push_back
(
std
::
move
(
cloned_outputs
));
}
DEVICE_CHECK
(
StreamSynchronize
(
stream
));
auto
get_stream
=
[
stream
,
use_unique_stream_per_thread
,
&
per_thread_streams
](
size_t
thread_idx
)
{
if
(
!
use_unique_stream_per_thread
)
{
return
stream
;
}
return
per_thread_streams
[
thread_idx
].
get
();
};
auto
thread_func
=
[
&
](
size_t
thread_idx
)
{
AITData
*
thread_outputs
=
thread_idx
==
0
?
outputs
:
per_thread_outputs
[
thread_idx
-
1
].
data
();
StreamType
thread_stream
=
get_stream
(
thread_idx
);
auto
*
thread_output_shapes_out
=
thread_idx
==
0
?
output_shapes_out
:
nullptr
;
return
BenchmarkImpl
(
inputs
,
num_inputs
,
thread_outputs
,
num_outputs
,
thread_stream
,
graph_mode
,
count
,
thread_output_shapes_out
);
};
std
::
vector
<
std
::
future
<
float
>>
futures
;
futures
.
reserve
(
num_threads
);
for
(
size_t
i
=
0
;
i
<
num_threads
;
++
i
)
{
futures
.
push_back
(
std
::
async
(
std
::
launch
::
async
,
thread_func
,
i
));
}
auto
max_time
=
std
::
accumulate
(
futures
.
begin
(),
futures
.
end
(),
0.
f
,
[](
float
cur_val
,
auto
&
future
)
{
return
std
::
max
(
future
.
get
(),
cur_val
);
});
// Verify that all the outputs are the same
for
(
size_t
i
=
0
;
i
<
num_outputs
;
++
i
)
{
auto
output_size
=
MaxOutputStorageBytes
(
i
);
auto
output_host
=
std
::
make_unique
<
uint8_t
[]
>
(
output_size
);
// NB: technically, we don't have to copy to host here, but using
// std::memcmp is easier than writing a kernel that does comparisons
// for both backends, and performance is not important here.
DEVICE_CHECK
(
CopyToHost
(
output_host
.
get
(),
outputs
[
i
].
ptr
,
output_size
,
stream
));
DEVICE_CHECK
(
StreamSynchronize
(
stream
));
for
(
size_t
thread_idx
=
1
;
thread_idx
<
num_threads
;
++
thread_idx
)
{
auto
*
thread_output
=
per_thread_outputs
[
thread_idx
-
1
][
i
].
ptr
;
auto
thread_output_host
=
std
::
make_unique
<
uint8_t
[]
>
(
output_size
);
auto
thread_stream
=
get_stream
(
thread_idx
);
DEVICE_CHECK
(
CopyToHost
(
thread_output_host
.
get
(),
thread_output
,
output_size
,
thread_stream
));
DEVICE_CHECK
(
StreamSynchronize
(
thread_stream
));
if
(
std
::
memcmp
(
output_host
.
get
(),
thread_output_host
.
get
(),
output_size
))
{
throw
std
::
runtime_error
(
"Output "
+
std
::
to_string
(
i
)
+
" did not match for a spawned thread!"
);
}
}
}
auto
total_num_iters
=
num_threads
*
count
;
return
max_time
/
total_num_iters
;
}
void
ModelContainer
::
SetConstant
(
const
char
*
name
,
const
AITData
&
tensor
)
{
auto
it
=
unbound_constant_name_to_idx_
.
find
(
name
);
if
(
it
==
unbound_constant_name_to_idx_
.
end
())
{
// TODO make this an exception after we fix the CMF benchmarks
LOG
(
ERROR
)
<<
"Constant "
<<
name
<<
" not found"
;
return
;
}
auto
constant_idx
=
it
->
second
+
num_inputs_
+
num_outputs_
;
ValidateDtype
(
tensor
.
dtype
,
constant_idx
);
CHECK_VECTOR_ACCESS
(
max_param_storage_bytes_
,
constant_idx
)
auto
expected_num_bytes
=
max_param_storage_bytes_
[
constant_idx
];
auto
actual_num_bytes
=
tensor
.
shape
.
Numel
()
*
AITemplateDtypeSizeBytes
(
tensor
.
dtype
);
if
(
expected_num_bytes
!=
actual_num_bytes
)
{
throw
std
::
runtime_error
(
std
::
string
(
"SetConstant did not recieve correct number of bytes for constant "
)
+
name
+
": expected "
+
std
::
to_string
(
expected_num_bytes
)
+
" but got "
+
std
::
to_string
(
actual_num_bytes
)
+
". Check that the provided tensor's shape is correct."
);
}
auto
*
src
=
tensor
.
ptr
;
for
(
auto
&
model
:
models_
)
{
model
.
SetConstant
(
name
,
src
);
}
}
size_t
ModelContainer
::
NumInputs
()
const
{
return
num_inputs_
;
}
const
char
*
ModelContainer
::
InputName
(
size_t
input_idx
)
const
{
CHECK_VECTOR_ACCESS
(
param_names_
,
input_idx
)
return
param_names_
[
input_idx
];
}
size_t
ModelContainer
::
NumOutputs
()
const
{
return
num_outputs_
;
}
const
char
*
ModelContainer
::
OutputName
(
size_t
output_idx
)
const
{
auto
idx
=
output_idx
+
num_inputs_
;
CHECK_VECTOR_ACCESS
(
param_names_
,
idx
)
return
param_names_
[
idx
];
}
AITemplateParamShape
ModelContainer
::
MaxOutputShape
(
size_t
output_idx
)
const
{
auto
idx
=
output_idx
+
num_inputs_
;
CHECK_VECTOR_ACCESS
(
max_param_shapes_
,
idx
)
auto
&
out_shape
=
max_param_shapes_
[
idx
];
return
AITemplateParamShape
{
out_shape
.
data
(),
out_shape
.
size
()};
}
AITemplateDtype
ModelContainer
::
OutputDtype
(
size_t
output_idx
)
const
{
auto
idx
=
output_idx
+
num_inputs_
;
CHECK_VECTOR_ACCESS
(
param_dtypes_
,
idx
)
return
param_dtypes_
[
idx
];
}
size_t
ModelContainer
::
MaxOutputStorageBytes
(
size_t
output_idx
)
const
{
auto
idx
=
output_idx
+
num_inputs_
;
CHECK_VECTOR_ACCESS
(
max_param_storage_bytes_
,
idx
)
return
max_param_storage_bytes_
[
idx
];
}
void
ModelContainer
::
PrepareForRun
(
Model
*
model
,
const
AITData
*
inputs
,
size_t
num_inputs
,
AITData
*
outputs
,
size_t
num_outputs
)
{
if
(
num_inputs
!=
num_inputs_
)
{
auto
msg
=
"Got wrong number of inputs; expected "
+
std
::
to_string
(
num_inputs_
)
+
", got "
+
std
::
to_string
(
num_inputs
);
throw
std
::
runtime_error
(
std
::
move
(
msg
));
}
if
(
num_inputs
>
0
&&
inputs
==
nullptr
)
{
throw
std
::
runtime_error
(
"inputs cannot be null"
);
}
if
(
num_outputs
!=
num_outputs_
)
{
auto
msg
=
"Got wrong number of outputs; expected "
+
std
::
to_string
(
num_outputs_
)
+
", got "
+
std
::
to_string
(
num_outputs
);
throw
std
::
runtime_error
(
std
::
move
(
msg
));
}
if
(
num_outputs
>
0
&&
outputs
==
nullptr
)
{
throw
std
::
runtime_error
(
"outputs cannot be null"
);
}
for
(
size_t
i
=
0
;
i
<
num_inputs_
;
++
i
)
{
auto
&
input
=
inputs
[
i
];
ValidateDtype
(
input
.
dtype
,
i
);
model
->
SetInput
(
input
.
ptr
,
input
.
shape
,
i
);
}
for
(
size_t
i
=
0
;
i
<
num_outputs_
;
++
i
)
{
auto
&
output
=
outputs
[
i
];
ValidateDtype
(
output
.
dtype
,
i
+
num_inputs_
);
model
->
SetOutput
(
output
.
ptr
,
i
);
}
}
Model
*
ModelContainer
::
GetAvailableModel
()
{
std
::
unique_lock
lk
(
models_mutex_
);
if
(
available_models_
.
empty
())
{
ReclaimFinishedModels
(
lk
);
}
auto
*
result
=
available_models_
.
back
();
available_models_
.
pop_back
();
return
result
;
}
void
ModelContainer
::
ReclaimFinishedModels
(
std
::
unique_lock
<
std
::
mutex
>&
lk
)
{
// Put any complete models at the end
auto
it
=
std
::
stable_partition
(
pending_models_
.
begin
(),
pending_models_
.
end
(),
[](
Model
*
m
)
{
return
m
->
IsPending
();
});
if
(
it
!=
pending_models_
.
end
())
{
// Move all available models to the pool.
available_models_
.
insert
(
available_models_
.
end
(),
it
,
pending_models_
.
end
());
pending_models_
.
erase
(
it
,
pending_models_
.
end
());
return
;
}
pending_models_available_
.
wait
(
lk
,
[
this
]()
{
return
!
pending_models_
.
empty
();
});
// There are no available workspaces! We have to wait on one.
auto
*
model
=
pending_models_
.
front
();
pending_models_
.
pop_front
();
lk
.
unlock
();
try
{
model
->
WaitForCompletion
();
}
catch
(...)
{
lk
.
lock
();
available_models_
.
push_back
(
model
);
throw
;
}
lk
.
lock
();
available_models_
.
push_back
(
model
);
}
void
ModelContainer
::
ValidateDtype
(
AITemplateDtype
dtype
,
size_t
idx
)
const
{
CHECK_VECTOR_ACCESS
(
param_dtypes_
,
idx
)
if
(
dtype
!=
param_dtypes_
[
idx
])
{
auto
GetEnumString
=
[](
auto
dtype
)
{
switch
(
dtype
)
{
case
AITemplateDtype
::
kUnset
:
return
"kUnset"
;
case
AITemplateDtype
::
kHalf
:
return
"kHalf"
;
case
AITemplateDtype
::
kFloat
:
return
"kFloat"
;
case
AITemplateDtype
::
kInt
:
return
"kInt"
;
case
AITemplateDtype
::
kLong
:
return
"kLong"
;
default:
return
"unknown"
;
}
};
throw
std
::
runtime_error
(
"Got wrong dtype for param "
+
std
::
to_string
(
idx
)
+
"; expected "
+
GetEnumString
(
param_dtypes_
[
idx
])
+
", got "
+
GetEnumString
(
dtype
));
}
}
float
ModelContainer
::
BenchmarkImpl
(
const
AITData
*
inputs
,
size_t
num_inputs
,
AITData
*
outputs
,
size_t
num_outputs
,
StreamType
stream
,
bool
graph_mode
,
size_t
count
,
int64_t
**
output_shapes_out
)
{
auto
*
model
=
GetAvailableModel
();
float
runtime_ms
=
0.
;
auto
start_event
=
RAII_CreateEvent
();
auto
end_event
=
RAII_CreateEvent
();
try
{
PrepareForRun
(
model
,
inputs
,
num_inputs
,
outputs
,
num_outputs
);
DEVICE_CHECK
(
EventRecord
(
start_event
.
get
(),
stream
));
for
(
size_t
i
=
0
;
i
<
count
;
++
i
)
{
model
->
Run
(
stream
,
graph_mode
);
}
}
catch
(...)
{
std
::
lock_guard
lk
(
models_mutex_
);
available_models_
.
push_back
(
model
);
throw
;
}
if
(
output_shapes_out
)
{
for
(
size_t
i
=
0
;
i
<
num_outputs
;
++
i
)
{
auto
*
out_shape
=
output_shapes_out
[
i
];
model
->
GetOutputShape
(
i
,
out_shape
);
}
}
// Push the model back into the pool before synchronizing the event
// to exercise the concurrency code
{
std
::
lock_guard
lk
(
models_mutex_
);
pending_models_
.
push_back
(
model
);
}
pending_models_available_
.
notify_one
();
DEVICE_CHECK
(
EventRecord
(
end_event
.
get
(),
stream
));
DEVICE_CHECK
(
EventSynchronize
(
end_event
.
get
()));
DEVICE_CHECK
(
EventElapsedTime
(
&
runtime_ms
,
start_event
.
get
(),
end_event
.
get
()));
LOG
(
INFO
)
<<
"Benchmark runtime ms/iter: "
<<
runtime_ms
/
count
;
return
runtime_ms
;
}
}
// namespace ait
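
ReclaimFinishedModels() above relies on std::stable_partition to keep the still-running models (IsPending() == true) at the front of the pending queue and gather the finished ones at the back, which can then be moved into the available pool in one step. A minimal standalone sketch of that idea follows; Job is a stand-in for Model and is not an AIT type.

// Standalone sketch of the partition-based reclaim strategy.
#include <algorithm>
#include <deque>
#include <iostream>
#include <vector>

struct Job { bool pending; }; // stand-in for Model::IsPending()

int main() {
  std::deque<Job*> pending;
  std::vector<Job*> available;
  Job jobs[4] = {{true}, {false}, {true}, {false}};
  for (auto& j : jobs) pending.push_back(&j);

  // Finished jobs (pending == false) end up in the tail of the deque.
  auto it = std::stable_partition(
      pending.begin(), pending.end(), [](Job* j) { return j->pending; });
  available.insert(available.end(), it, pending.end());
  pending.erase(it, pending.end());

  std::cout << "reclaimed " << available.size() << ", still pending "
            << pending.size() << "\n"; // prints: reclaimed 2, still pending 2
  return 0;
}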
python/AIT implementation/sample files/model_container.h
0 → 100644
View file @
516bbdcb
#pragma once
#include "model-generated.h"
#include "model_interface.h"
#include "raii_wrapper.h"
#include <condition_variable>
#include <cstring>
#include <future>
#include <mutex>
#include <numeric>
#include <unordered_map>

namespace ait {

// ModelContainer inherits from this class; its implementation is
// generated at compilation time. Most of the ModelContainer
// logic does not need codegen; anything that does should be put
// into this class instead.
class ModelContainerBase {
 public:
  ModelContainerBase(
      size_t num_inputs,
      size_t num_outputs,
      size_t num_unbound_constants,
      size_t params_size,
      AITemplateAllocator& allocator);

 protected:
  // The set of unbound constants/weights/parameters. These are constants which
  // have no value at compile time and do not participate in constant folding.
  // They must be set via SetConstant prior to inference.
  std::unordered_map<std::string, size_t> unbound_constant_name_to_idx_;
  // a single piece of memory for all constants
  GPUPtr constants_;
  // size of the containers below: # inputs + # outputs + # unbound constants.
  size_t num_params_;
  // These entries correspond to inputs/outputs/unbound constants in order;
  // inputs first, then outputs, then constants.
  std::vector<const char*> param_names_;
  std::vector<std::vector<int64_t>> max_param_shapes_;
  std::vector<AITemplateDtype> param_dtypes_;
  // NB: technically these could be derived from both the max shape and
  // the dtype, but it's easier to just cache them.
  std::vector<size_t> max_param_storage_bytes_;
  std::vector<size_t> max_param_numel_;
};

// This creates a new ModelContainer; its implementation is also
// codegened (the parameters passed to the ctor are determined
// at compilation time)
class ModelContainer;
ModelContainer* CreateModelContainer(
    size_t num_runtimes,
    AITemplateAllocator& allocator);

// Each ModelContainer contains num_models Models. Inference runs
// can be started by invoking Run() with lists of pre-allocated
// input/output tensors. GetOutputMaximumShape() can be used to
// determine how much memory is required for each output.
//
// If there are N tensors marked with is_output=True,
// the user will always be expected to pass N output pointers -
// extra copies will occur if the outputs are views of constants,
// inputs, or other outputs in this case to avoid surprises.
//
// Use stream = nullptr for the default stream. ModelContainer/Model does not
// create or own any stream. The user is expected to create and manage streams.
//
// We can support at most num_models concurrent inferences.
// Run() takes a stream to run the inference on. For example,
// to start up two inferences on different streams concurrently,
// we can do this:
//
// model_container.Run(inputs0, num_inputs, outputs0, num_outputs, stream0, ...);
// model_container.Run(inputs1, num_inputs, outputs1, num_outputs, stream1, ...);
// StreamSynchronize(stream0);
// StreamSynchronize(stream1);
//
// Note that if there are no models available for inference, Run() will block
// until one becomes available.
//
// ModelContainer optionally takes an allocator argument, which it will use to
// allocate the space for the buffers used for intermediate tensors and
// constants. If it is nullptr, the default allocator will be used (e.g. just
// {cuda/hip}{Malloc/Free}).
// Important: we assume that the allocator lives until the ModelContainer is
// destroyed. The default allocator has a static lifetime.
class ModelContainer : ModelContainerBase {
 public:
  ModelContainer(
      size_t num_models,
      size_t blob_size,
      size_t workspace_size,
      size_t num_inputs,
      size_t num_outputs,
      size_t num_unbound_constants,
      size_t params_size,
      AITemplateAllocator& allocator);

  void Run(
      const AITData* inputs,
      size_t num_inputs,
      AITData* outputs,
      size_t num_outputs,
      StreamType stream,
      bool sync,
      bool graph_mode,
      int64_t** output_shapes_out);

  void RunWithOutputsOnHost(
      const AITData* inputs,
      size_t num_inputs,
      AITData* outputs,
      size_t num_outputs,
      StreamType stream,
      bool graph_mode,
      int64_t** output_shapes_out);

  float Benchmark(
      const AITData* inputs,
      size_t num_inputs,
      AITData* outputs,
      size_t num_outputs,
      StreamType stream,
      bool graph_mode,
      size_t count,
      size_t num_threads,
      bool use_unique_stream_per_thread,
      int64_t** output_shapes_out);

  void SetConstant(const char* name, const AITData& tensor);

  size_t NumInputs() const;
  size_t NumOutputs() const;
  const char* InputName(size_t input_idx) const;
  const char* OutputName(size_t output_idx) const;
  AITemplateParamShape MaxOutputShape(size_t output_idx) const;
  AITemplateDtype OutputDtype(size_t output_idx) const;
  size_t MaxOutputStorageBytes(size_t output_idx) const;

  size_t GetNumRuntimes() const {
    return models_.size();
  }

 private:
  void PrepareForRun(
      Model* model,
      const AITData* inputs,
      size_t num_inputs,
      AITData* outputs,
      size_t num_outputs);
  Model* GetAvailableModel();
  void ReclaimFinishedModels(std::unique_lock<std::mutex>& lk);
  void ValidateDtype(AITemplateDtype dtype, size_t idx) const;
  float BenchmarkImpl(
      const AITData* inputs,
      size_t num_inputs,
      AITData* outputs,
      size_t num_outputs,
      StreamType stream,
      bool graph_mode,
      size_t count,
      int64_t** output_shapes_out);

  AITemplateAllocator& allocator_;
  std::vector<Model> models_;
  std::vector<Model*> available_models_;
  std::deque<Model*> pending_models_;
  // Guards accesses to available/pending models.
  std::mutex models_mutex_;
  // Notified whenever a model is put into pending_models_.
  std::condition_variable pending_models_available_;
  size_t num_inputs_;
  size_t num_outputs_;
};

} // namespace ait
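
The comment block above sketches the intended calling pattern; spelled out as a compilable caller it looks like the function below. This is a hypothetical usage sketch, not part of the commit: the names stream0/stream1 and the AITData arrays are illustrative, error handling is omitted, and it assumes the container was created with at least two runtimes so the second Run() does not block.

// Hypothetical caller-side sketch of the two-stream usage described above.
#include "model_container.h"

void RunTwoInferencesConcurrently(
    ait::ModelContainer& container,
    AITData* inputs0, AITData* outputs0,
    AITData* inputs1, AITData* outputs1,
    StreamType stream0, StreamType stream1) {
  const size_t num_inputs = container.NumInputs();   // 2 for this model
  const size_t num_outputs = container.NumOutputs(); // 1 for this model
  // Each Run() checks out one Model from the pool and records its completion
  // event on the given stream; with num_models >= 2 both runs overlap.
  container.Run(inputs0, num_inputs, outputs0, num_outputs, stream0,
                /*sync=*/false, /*graph_mode=*/false, /*output_shapes_out=*/nullptr);
  container.Run(inputs1, num_inputs, outputs1, num_outputs, stream1,
                /*sync=*/false, /*graph_mode=*/false, /*output_shapes_out=*/nullptr);
  StreamSynchronize(stream0);
  StreamSynchronize(stream1);
}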
python/AIT implementation/sample files/model_container_base.cpp
0 → 100644
View file @
516bbdcb
#include "model_container.h"
#include "owned_constants.h"
namespace
ait
{
namespace
{
// Contains the metadata for each constant.
constexpr
std
::
array
<
ConstantInfo
,
0
>
owned_constants
=
{
};
}
// namespace
ModelContainerBase
::
ModelContainerBase
(
size_t
num_inputs
,
size_t
num_outputs
,
size_t
num_unbound_constants
,
size_t
params_size
,
AITemplateAllocator
&
allocator
)
:
constants_
(
RAII_DeviceMalloc
(
params_size
,
allocator
)),
num_params_
(
num_inputs
+
num_outputs
+
num_unbound_constants
),
param_names_
(
num_params_
),
param_dtypes_
(
num_params_
),
max_param_shapes_
(
num_params_
),
max_param_numel_
(
num_params_
),
max_param_storage_bytes_
(
num_params_
)
{
param_names_
[
0
]
=
"input_0"
;
param_names_
[
1
]
=
"input_1"
;
param_names_
[
2
]
=
"output_0"
;
param_dtypes_
[
0
]
=
AITemplateDtype
::
kHalf
;
param_dtypes_
[
1
]
=
AITemplateDtype
::
kHalf
;
param_dtypes_
[
2
]
=
AITemplateDtype
::
kHalf
;
max_param_shapes_
[
0
]
=
{
256
,
128
};
max_param_shapes_
[
1
]
=
{
128
,
32
};
max_param_shapes_
[
2
]
=
{
256
,
32
};
for
(
size_t
i
=
0
;
i
<
num_params_
;
++
i
)
{
max_param_numel_
[
i
]
=
std
::
accumulate
(
max_param_shapes_
[
i
].
begin
(),
max_param_shapes_
[
i
].
end
(),
1
,
std
::
multiplies
<
int64_t
>
()
);
max_param_storage_bytes_
[
i
]
=
max_param_numel_
[
i
]
*
AITemplateDtypeSizeBytes
(
param_dtypes_
[
i
]);
}
auto
*
constants_ptr
=
static_cast
<
uint8_t
*>
(
constants_
.
get
());
const
auto
binary_constants_bin_size
=
static_cast
<
size_t
>
(
_binary_constants_bin_end
-
_binary_constants_bin_start
);
for
(
auto
&
constant_info
:
owned_constants
)
{
auto
*
dst
=
constants_ptr
+
constant_info
.
internal_offset
;
if
(
constant_info
.
data_offset
+
constant_info
.
num_bytes
>
binary_constants_bin_size
)
{
throw
std
::
runtime_error
(
std
::
string
(
"Copying constant "
)
+
constant_info
.
name
+
" would overflow constant buffer"
);
}
DEVICE_CHECK
(
CopyToDevice
(
dst
,
_binary_constants_bin_start
+
constant_info
.
data_offset
,
constant_info
.
num_bytes
));
}
}
ModelContainer
*
CreateModelContainer
(
size_t
num_runtimes
,
AITemplateAllocator
&
allocator
)
{
// num_runtimes, blob_size, workspace_size, num_inputs, num_outputs, num_unbound_constants, param_size, allocator
return
new
ModelContainer
(
num_runtimes
,
90112
,
0
,
2
,
1
,
0
,
0
,
allocator
);
}
}
// namespace ait
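
For this generated model the three parameters are fp16 (AITemplateDtype::kHalf, 2 bytes per element), so the storage-byte loop above resolves to 65536, 8192, and 16384 bytes for input_0 (256x128), input_1 (128x32), and output_0 (256x32) respectively. A quick standalone check of that arithmetic (illustrative only, no AIT types):

// Verifies max_param_storage_bytes_ = numel * dtype size for the three params.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  const size_t kHalfBytes = 2; // AITemplateDtype::kHalf
  std::vector<std::vector<int64_t>> shapes = {{256, 128}, {128, 32}, {256, 32}};
  std::vector<size_t> expected = {65536, 8192, 16384};
  for (size_t i = 0; i < shapes.size(); ++i) {
    auto numel = std::accumulate(shapes[i].begin(), shapes[i].end(),
                                 int64_t{1}, std::multiplies<int64_t>());
    assert(static_cast<size_t>(numel) * kHalfBytes == expected[i]);
  }
  return 0;
}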