Commit 3133fd79 authored by Alan Turner

Formatting

parent d7ea085c
@@ -79,7 +79,7 @@ struct ck_gemm_scale_bias_softmax_gemm
        auto b1 = inputs[2];
        for(const auto& input : inputs)
        {
            // std::cout << input << std::endl;
            check_gemm_shape(input);
        }
        return op.compute_shape({op.compute_shape({a, b}), b1});

@@ -158,10 +158,13 @@ struct find_ck_gemm_scale_bias_softmax_gemm
{
    auto matcher() const
    {
        auto gemm1 =
            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
        auto pw =
            match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
        auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
            match::any_of[match::inputs()](softmax));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const

@@ -180,16 +183,19 @@ struct find_ck_gemm_scale_bias_softmax_gemm
        auto inputs = gemm1_ins->inputs(); // A, B
        inputs.push_back(gemm2_ins->inputs().back()); // B1
        // inputs.push_back(pw_ins->inputs().back()); // C

        mpm.get_module().replace_instruction(
            ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
    }

    // auto matcher() const
    // {
    //     auto gemm1 =
    //     match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
    //     auto softmax =
    //     match::name("softmax")(match::any_of[match::inputs()](gemm1)).bind("softmax"); return
    //     match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
    // }

    // void apply(module_pass_manager& mpm, const match::matcher_result& r) const

@@ -207,7 +213,8 @@ struct find_ck_gemm_scale_bias_softmax_gemm
    //     auto inputs = gemm1_ins->inputs(); // A, B
    //     inputs.push_back(gemm2_ins->inputs().back()); // B1
    //     mpm.get_module().replace_instruction(ins,
    //     ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
    // }
};
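For orientation, the matcher above fuses dot -> pointwise(scale, bias) -> softmax -> dot into one ck_gemm_scale_bias_softmax_gemm instruction. A minimal numpy sketch of the computation that fused instruction stands for (the shapes, scale and bias below are illustrative assumptions, not values taken from the code):

import numpy as np

def gemm_scale_bias_softmax_gemm(a, b, b1, scale=0.125, bias=0.0):
    # first GEMM, then the pointwise scale/bias the matcher binds as "scale_bias"
    x = a @ b * scale + bias
    # row-wise softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    p = e / e.sum(axis=-1, keepdims=True)
    # second GEMM against B1
    return p @ b1

a, b, b1 = np.random.rand(8, 16), np.random.rand(16, 8), np.random.rand(8, 4)
out = gemm_scale_bias_softmax_gemm(a, b, b1)  # shape (8, 4)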
@@ -213,7 +213,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
        return "ck::Tuple<" + join_strings(s, ",") + ">";
    }

    std::vector<std::string> names() const
    {
        return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
    }

    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {

@@ -259,14 +262,15 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
        auto gemm1_nperblock = 64;
        auto gemm01_mperblock = 256;
        auto blocks_per_batch = int_div_ceil(m, gemm01_mperblock) *
                                int_div_ceil(n, gemm1_nperblock); // ip.get_grid_size(config);
        auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
                                           c_shape.lens().rend(),
                                           std::size_t{1},
                                           std::multiplies<std::size_t>());
        hip_compile_options options;
        auto block_size = 256; // ip.get_block_size();
        auto grid_size  = batch_count * blocks_per_batch;
        options.set_launch_params(v, grid_size * block_size, block_size);
        options.inputs = inputs;

@@ -278,7 +282,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
        options.params += " -DMIGRAPHX_CK_CHECK=1";
        auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
                                      {{"instance", "" /* ip.str() */},
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"blocks_per_batch", to_string(blocks_per_batch)},

@@ -296,7 +300,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
        {
            auto* pm = ins->module_inputs().front();
            v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
                            "post_ck_gemm_softmax_gemm_function);";
            v["post"]   = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
            v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
        }

@@ -306,7 +311,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
            if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
            {
                std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
                std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
                          << std::endl;
            }
        });
    }
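The launch configuration above is plain tile arithmetic: each batch needs int_div_ceil(m, 256) * int_div_ceil(n, 64) workgroups of 256 threads, and the batch count is the product of every output dimension except the last two. A small worked example in Python (the m, n and batch sizes are illustrative assumptions; which tensor dimension n refers to is decided by code outside this hunk):

from math import prod

def int_div_ceil(x, y):
    return (x + y - 1) // y

c_lens = [2, 12, 384, 64]   # hypothetical output shape: [batch, heads, m, n]
m, n = c_lens[-2], c_lens[-1]
gemm01_mperblock, gemm1_nperblock, block_size = 256, 64, 256

blocks_per_batch = int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock)  # 2 * 1 = 2
batch_count = prod(c_lens[:-2])             # 2 * 12 = 24
grid_size = batch_count * blocks_per_batch  # 48 workgroups
print(grid_size * block_size, block_size)   # what set_launch_params receives: 12288, 256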
@@ -54,14 +54,15 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto m = a_shape.lens[0];
    constexpr const auto k = a_shape.lens[1];
    constexpr const auto sa = a_shape.strides[0];
    constexpr const auto a_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(m, k), ck::make_tuple(sa, 1));
    constexpr const auto a_grid_desc_mraw_kraw = gemm.matrix_padder.PadADescriptor_M_K(a_tensor);
    constexpr const auto AK1 = gemm.get_AK1();
    constexpr const auto AK0 = k / AK1;
    constexpr const auto a_grid_desc_ak0_m_ak1 = ck::transform_tensor_descriptor(
        a_grid_desc_mraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(AK0, AK1)),
                       ck::make_pass_through_transform(m)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),

@@ -73,10 +74,11 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto BK1 = gemm.get_BK1();
    constexpr const auto BK0 = k / BK1;
    constexpr const auto b_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(n, k), ck::make_tuple(sb, 1));
    constexpr const auto b_grid_desc_nraw_kraw = gemm.matrix_padder.PadBDescriptor_N_K(b_tensor);
    constexpr const auto b_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
        b_grid_desc_nraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(BK0, BK1)),
                       ck::make_pass_through_transform(n)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),

@@ -89,10 +91,11 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto B1K1 = gemm.get_B1K1();
    constexpr const auto B1K0 = k1 / B1K1;
    constexpr const auto b1_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(n1, k1), ck::make_tuple(1, sb1));
    constexpr const auto b1_grid_desc_nraw_kraw = gemm.matrix_padder.PadB1Descriptor_N_K(b1_tensor);
    constexpr const auto b1_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
        b1_grid_desc_nraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(B1K0, B1K1)),
                       ck::make_pass_through_transform(n1)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),

@@ -100,11 +103,10 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto c_shape = get_shape_c<C>{};
    constexpr const auto sc = c_shape.strides[0];
    constexpr const auto c_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(m, n1), ck::make_tuple(sc, 1));
    constexpr const auto c_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(c_tensor);

    constexpr const auto MPerBlock = gemm.get_mperblock();
    constexpr const auto Gemm1NPerBlock = gemm.get_gemm1nperblock();
    constexpr const auto MBlock = m / MPerBlock;

@@ -112,18 +114,20 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        ck::transform_tensor_descriptor(
            c_grid_desc_m_n,
            ck::make_tuple(
                ck::make_unmerge_transform(ck::make_tuple(MBlock, ck::Number<MPerBlock>{})),
                ck::make_unmerge_transform(ck::make_tuple(NBlock, ck::Number<Gemm1NPerBlock>{}))),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}));

    constexpr const auto block_2_ctile_map =
        BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
            c_grid_desc_m_n);

    const C0MatrixMask c0_matrix_mask(n);

    const auto K = a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
                   a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});

    using gridwise = typename G::template rt_gridwisegemm<decltype(a_grid_desc_ak0_m_ak1),
                                                          decltype(b_grid_desc_bk0_n_bk1),

@@ -135,7 +139,6 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                              b_grid_desc_bk0_n_bk1,
                                              b1_grid_desc_bk0_n_bk1,
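As a sanity check on the descriptor transforms above: the unmerge splits the K axis of the naive (m, k), stride-(sa, 1) descriptor into (AK0, AK1) with k = AK0 * AK1, so the transformed (AK0, m, AK1) view indexes the same memory as the linear offset i * sa + ak0 * AK1 + ak1. A small numpy illustration under assumed sizes (in the kernel AK1 comes from gemm.get_AK1()):

import numpy as np

m, k, AK1 = 4, 16, 8      # illustrative sizes only
AK0, sa = k // AK1, k     # dense row-major (m, k): strides (sa, 1) with sa = k

a = np.arange(m * k).reshape(m, k)
a_ak0_m_ak1 = a.reshape(m, AK0, AK1).transpose(1, 0, 2)  # (AK0, m, AK1) view

# the unmerge maps (ak0, i, ak1) back to the linear offset i * sa + ak0 * AK1 + ak1
ak0, i, ak1 = 1, 2, 3
assert a_ak0_m_ak1[ak0, i, ak1] == a.flat[i * sa + ak0 * AK1 + ak1]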
@@ -121,10 +121,7 @@ struct C0MatrixMask
    __device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }

    __device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const { return n >= NRaw_; }

    __device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
    {

@@ -197,8 +194,8 @@ template <typename ALayout,
          ck::LoopScheduler LoopSched = ck::LoopScheduler::Default>
struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
    static constexpr auto matrix_padder = ck::tensor_operation::device::
        GemmGemmPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t, ck::index_t>{
            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};

    static constexpr auto get_AK1() { return AK1; };

@@ -215,7 +212,7 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    CElementwiseOperation c_element_op{};
    AccElementwiseOperation acc_element_op{alpha};

    template <typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename B1GridDesc_BK0_N_BK1,
              typename CGridDesc_M_N>

@@ -286,6 +283,5 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    };
};
} // namespace migraphx
#endif
@@ -31,6 +31,7 @@ from onnx import TensorProto
def onnx_test(op_test):

    def run_test():
        op_info = op_test()
        if len(op_info) > 3:

@@ -1995,12 +1996,13 @@ def gemm_softmax_gemm_test():
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT16, [1, 1])
    out = helper.make_tensor_value_info('out', TensorProto.FLOAT16, [1, 1])

    scale_array = np.array([(1 / 8)])
    scale_tensor = helper.make_tensor(name='scale',
                                      data_type=TensorProto.FLOAT16,
                                      dims=scale_array.shape,
                                      vals=scale_array.flatten().astype(
                                          np.float16))

    gemm1 = onnx.helper.make_node('MatMul',
                                  inputs=['a', 'b'],

@@ -2018,8 +2020,8 @@ def gemm_softmax_gemm_test():
                                  inputs=['softmax_out', 'b1'],
                                  outputs=['out'])

    return ([gemm1, mul1, add1, softmax, gemm2], [a, b, c, b1,
                                                  bias], [out], [scale_tensor])

@onnx_test
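The 4-tuple returned above (nodes, graph inputs, outputs, initializers) is what the onnx_test decorator presumably assembles into a saved model elsewhere in this script. A rough standalone sketch of that step with the stock onnx.helper API, for illustration only (the function name, graph name and file path are made up):

import onnx
from onnx import helper

def save_model(nodes, graph_inputs, graph_outputs, initializers, path):
    # build a GraphProto from the node list, value infos and initializer tensors,
    # wrap it in a ModelProto and write it to disk
    graph = helper.make_graph(nodes,
                              'gemm_softmax_gemm_test',
                              graph_inputs,
                              graph_outputs,
                              initializer=initializers)
    onnx.save(helper.make_model(graph), path)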