Commit 3133fd79 authored by Alan Turner

Formatting

parent d7ea085c
@@ -74,12 +74,12 @@ struct ck_gemm_scale_bias_softmax_gemm
         // MIGRAPHX_THROW("should have one submodule.");
         if(inputs.size() < 2)
             MIGRAPHX_THROW("should have at least two inputs.");
-        auto a = inputs[0];
-        auto b = inputs[1];
+        auto a  = inputs[0];
+        auto b  = inputs[1];
         auto b1 = inputs[2];
         for(const auto& input : inputs)
         {
-            //std::cout << input << std::endl;
+            // std::cout << input << std::endl;
            check_gemm_shape(input);
         }
         return op.compute_shape({op.compute_shape({a, b}), b1});
@@ -158,19 +158,22 @@ struct find_ck_gemm_scale_bias_softmax_gemm
 {
     auto matcher() const
     {
-        auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
-        auto pw = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
+        auto gemm1 =
+            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+        auto pw =
+            match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
         auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
-        return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
+        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
+            match::any_of[match::inputs()](softmax));
     }
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         std::cout << "Matched" << std::endl;
-        auto ins = r.result;
+        auto ins       = r.result;
         auto gemm2_ins = r.instructions["gemm2"];
-        auto sm_ins = r.instructions["softmax"];
-        auto pw_ins = r.instructions["scale_bias"];
+        auto sm_ins    = r.instructions["softmax"];
+        auto pw_ins    = r.instructions["scale_bias"];
         auto gemm1_ins = r.instructions["gemm1"];
         gemm2_ins->debug_print();
@@ -178,18 +181,21 @@ struct find_ck_gemm_scale_bias_softmax_gemm
         pw_ins->debug_print();
         gemm1_ins->debug_print();
-        auto inputs = gemm1_ins->inputs(); // A, B
+        auto inputs = gemm1_ins->inputs();            // A, B
         inputs.push_back(gemm2_ins->inputs().back()); // B1
-        //inputs.push_back(pw_ins->inputs().back()); // C
+        // inputs.push_back(pw_ins->inputs().back()); // C
-        mpm.get_module().replace_instruction(ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
+        mpm.get_module().replace_instruction(
+            ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
     }
     // auto matcher() const
     // {
-    // auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
-    // auto softmax = match::name("softmax")(match::any_of[match::inputs()](gemm1)).bind("softmax");
-    // return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
+    // auto gemm1 =
+    // match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+    // auto softmax =
+    // match::name("softmax")(match::any_of[match::inputs()](gemm1)).bind("softmax"); return
+    // match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
     // }
     // void apply(module_pass_manager& mpm, const match::matcher_result& r) const
@@ -207,7 +213,8 @@ struct find_ck_gemm_scale_bias_softmax_gemm
     // auto inputs = gemm1_ins->inputs(); // A, B
     // inputs.push_back(gemm2_ins->inputs().back()); // B1
-    // mpm.get_module().replace_instruction(ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
+    // mpm.get_module().replace_instruction(ins,
+    // ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
     // }
 };
@@ -213,16 +213,19 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         return "ck::Tuple<" + join_strings(s, ",") + ">";
     }
-    std::vector<std::string> names() const { return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"}; }
+    std::vector<std::string> names() const
+    {
+        return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
+    }
     operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
     {
         auto a_shape = inputs[0];
         auto b_shape = inputs[1];
         auto c_shape = inputs.back();
-        auto m = a_shape.lens()[0];
-        auto k = a_shape.lens()[1];
-        auto n = c_shape.lens()[1];
+        auto m    = a_shape.lens()[0];
+        auto k    = a_shape.lens()[1];
+        auto n    = c_shape.lens()[1];
         auto rank = a_shape.lens().size();
@@ -257,16 +260,17 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         // gemm_type += "Padding";
         // ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
-        auto gemm1_nperblock = 64;
+        auto gemm1_nperblock  = 64;
         auto gemm01_mperblock = 256;
-        auto blocks_per_batch = int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock);//ip.get_grid_size(config);
-        auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
+        auto blocks_per_batch = int_div_ceil(m, gemm01_mperblock) *
+                                int_div_ceil(n, gemm1_nperblock); // ip.get_grid_size(config);
+        auto batch_count      = std::accumulate(c_shape.lens().rbegin() + 2,
                                                 c_shape.lens().rend(),
                                                 std::size_t{1},
                                                 std::multiplies<std::size_t>());
         hip_compile_options options;
-        auto block_size = 256; //ip.get_block_size();
+        auto block_size = 256; // ip.get_block_size();
         auto grid_size = batch_count * blocks_per_batch;
         options.set_launch_params(v, grid_size * block_size, block_size);
         options.inputs = inputs;
@@ -278,7 +282,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         options.params += " -DMIGRAPHX_CK_CHECK=1";
         auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
-                                      {{"instance", ""/* ip.str() */},
+                                      {{"instance", "" /* ip.str() */},
                                        {"params", enum_params(inputs.size(), "void * private_p")},
                                        {"args", enum_params(inputs.size(), "private_p")},
                                        {"blocks_per_batch", to_string(blocks_per_batch)},
@@ -296,7 +300,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         {
             auto* pm = ins->module_inputs().front();
             v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
-                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);";
+                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
+                            "post_ck_gemm_softmax_gemm_function);";
             v["post"] = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
             v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
         }
@@ -306,7 +311,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
         {
             std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
-            std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl;
+            std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
+                      << std::endl;
         }
     });
 }
@@ -116,7 +116,7 @@ struct ck_scale
     template <class T, class U>
     constexpr void operator()(T& y, U x) const
     {
         y = x * static_cast<U>(scale);
     }
     float scale;
@@ -51,96 +51,99 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
     constexpr const G gemm{};
     constexpr const auto a_shape = get_shape_c<A>{};
-    constexpr const auto m = a_shape.lens[0];
-    constexpr const auto k = a_shape.lens[1];
-    constexpr const auto sa = a_shape.strides[0];
-    constexpr const auto a_tensor = ck::make_naive_tensor_descriptor(ck::make_tuple(m, k),
-                                                                     ck::make_tuple(sa, 1));
+    constexpr const auto m        = a_shape.lens[0];
+    constexpr const auto k        = a_shape.lens[1];
+    constexpr const auto sa       = a_shape.strides[0];
+    constexpr const auto a_tensor =
+        ck::make_naive_tensor_descriptor(ck::make_tuple(m, k), ck::make_tuple(sa, 1));
     constexpr const auto a_grid_desc_mraw_kraw = gemm.matrix_padder.PadADescriptor_M_K(a_tensor);
     constexpr const auto AK1 = gemm.get_AK1();
     constexpr const auto AK0 = k / AK1;
-    constexpr const auto a_grid_desc_ak0_m_ak1 = ck::transform_tensor_descriptor(a_grid_desc_mraw_kraw,
-        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(AK0, AK1)),
-                       ck::make_pass_through_transform(m)),
-        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+    constexpr const auto a_grid_desc_ak0_m_ak1 = ck::transform_tensor_descriptor(
+        a_grid_desc_mraw_kraw,
+        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(AK0, AK1)),
+                       ck::make_pass_through_transform(m)),
+        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
     constexpr const auto b_shape = get_shape_c<B>{};
-    constexpr const auto n = b_shape.lens[0]; // col-major
-    constexpr const auto sb = b_shape.strides[1]; // col-major
-    constexpr const auto BK1 = gemm.get_BK1();
-    constexpr const auto BK0 = k / BK1;
-    constexpr const auto b_tensor = ck::make_naive_tensor_descriptor(ck::make_tuple(n, k),
-                                                                     ck::make_tuple(sb, 1));
+    constexpr const auto n        = b_shape.lens[0];    // col-major
+    constexpr const auto sb       = b_shape.strides[1]; // col-major
+    constexpr const auto BK1      = gemm.get_BK1();
+    constexpr const auto BK0      = k / BK1;
+    constexpr const auto b_tensor =
+        ck::make_naive_tensor_descriptor(ck::make_tuple(n, k), ck::make_tuple(sb, 1));
     constexpr const auto b_grid_desc_nraw_kraw = gemm.matrix_padder.PadBDescriptor_N_K(b_tensor);
-    constexpr const auto b_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(b_grid_desc_nraw_kraw,
-        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(BK0, BK1)),
-                       ck::make_pass_through_transform(n)),
-        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+    constexpr const auto b_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
+        b_grid_desc_nraw_kraw,
+        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(BK0, BK1)),
+                       ck::make_pass_through_transform(n)),
+        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
     constexpr const auto b1_shape = get_shape_c<B1>{};
-    constexpr const auto k1 = b1_shape.lens[0]; // row-major
-    constexpr const auto n1 = b1_shape.lens[1]; // row-major
-    constexpr const auto sb1 = b1_shape.strides[0]; // row-major
-    constexpr const auto B1K1 = gemm.get_B1K1();
-    constexpr const auto B1K0 = k1 / B1K1;
-    constexpr const auto b1_tensor = ck::make_naive_tensor_descriptor(ck::make_tuple(n1, k1),
-                                                                      ck::make_tuple(1, sb1));
+    constexpr const auto k1        = b1_shape.lens[0];    // row-major
+    constexpr const auto n1        = b1_shape.lens[1];    // row-major
+    constexpr const auto sb1       = b1_shape.strides[0]; // row-major
+    constexpr const auto B1K1      = gemm.get_B1K1();
+    constexpr const auto B1K0      = k1 / B1K1;
+    constexpr const auto b1_tensor =
+        ck::make_naive_tensor_descriptor(ck::make_tuple(n1, k1), ck::make_tuple(1, sb1));
     constexpr const auto b1_grid_desc_nraw_kraw = gemm.matrix_padder.PadB1Descriptor_N_K(b1_tensor);
-    constexpr const auto b1_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
-        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(B1K0, B1K1)),
-                       ck::make_pass_through_transform(n1)),
-        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
-        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
+    constexpr const auto b1_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
+        b1_grid_desc_nraw_kraw,
+        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(B1K0, B1K1)),
+                       ck::make_pass_through_transform(n1)),
+        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
+        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
     constexpr const auto c_shape = get_shape_c<C>{};
-    constexpr const auto sc = c_shape.strides[0];
-    constexpr const auto c_tensor = ck::make_naive_tensor_descriptor(ck::make_tuple(m, n1),
-                                                                     ck::make_tuple(sc, 1));
+    constexpr const auto sc       = c_shape.strides[0];
+    constexpr const auto c_tensor =
+        ck::make_naive_tensor_descriptor(ck::make_tuple(m, n1), ck::make_tuple(sc, 1));
     constexpr const auto c_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(c_tensor);
-    constexpr const auto MPerBlock = gemm.get_mperblock();
+    constexpr const auto MPerBlock      = gemm.get_mperblock();
     constexpr const auto Gemm1NPerBlock = gemm.get_gemm1nperblock();
-    constexpr const auto MBlock = m / MPerBlock;
-    constexpr const auto NBlock = n1 / Gemm1NPerBlock;
-    constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-        ck::transform_tensor_descriptor(
-            c_grid_desc_m_n,
-            ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(MBlock, ck::Number<MPerBlock>{})),
-                           ck::make_unmerge_transform(ck::make_tuple(NBlock, ck::Number<Gemm1NPerBlock>{}))),
-            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
-            ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}));
-    constexpr const auto block_2_ctile_map = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
-        c_grid_desc_m_n);
+    constexpr const auto MBlock         = m / MPerBlock;
+    constexpr const auto NBlock         = n1 / Gemm1NPerBlock;
+    constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+        ck::transform_tensor_descriptor(
+            c_grid_desc_m_n,
+            ck::make_tuple(
+                ck::make_unmerge_transform(ck::make_tuple(MBlock, ck::Number<MPerBlock>{})),
+                ck::make_unmerge_transform(ck::make_tuple(NBlock, ck::Number<Gemm1NPerBlock>{}))),
+            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
+            ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}));
+    constexpr const auto block_2_ctile_map =
+        BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
+            c_grid_desc_m_n);
     const C0MatrixMask c0_matrix_mask(n);
-    const auto K =
-        a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});
+    const auto K = a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
+                   a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});
     using gridwise = typename G::template rt_gridwisegemm<decltype(a_grid_desc_ak0_m_ak1),
-        decltype(b_grid_desc_bk0_n_bk1),
-        decltype(b1_grid_desc_bk0_n_bk1),
-        decltype(c_grid_desc_m_n)>;
+                                                          decltype(b_grid_desc_bk0_n_bk1),
+                                                          decltype(b1_grid_desc_bk0_n_bk1),
+                                                          decltype(c_grid_desc_m_n)>;
     using GridwiseGemm = typename gridwise::GridwiseGemm;
     constexpr const bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
-        b_grid_desc_bk0_n_bk1,
-        b1_grid_desc_bk0_n_bk1,
-        c_grid_desc_m_n,
-        block_2_ctile_map));
+                                              b_grid_desc_bk0_n_bk1,
+                                              b1_grid_desc_bk0_n_bk1,
+                                              c_grid_desc_m_n,
+                                              block_2_ctile_map));
     GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
                                                   to_ck_const_pointer(b.data()),
                                                   to_ck_const_pointer(b1.data()),
@@ -121,10 +121,7 @@ struct C0MatrixMask
     __device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }
-    __device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const
-    {
-        return n >= NRaw_;
-    }
+    __device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const { return n >= NRaw_; }
     __device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
     {
@@ -197,8 +194,8 @@ template <typename ALayout,
           ck::LoopScheduler LoopSched = ck::LoopScheduler::Default>
 struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
 {
-    static constexpr auto matrix_padder =
-        ck::tensor_operation::device::GemmGemmPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t, ck::index_t>{
+    static constexpr auto matrix_padder = ck::tensor_operation::device::
+        GemmGemmPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t, ck::index_t>{
             MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
     static constexpr auto get_AK1() { return AK1; };
@@ -215,11 +212,11 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
     CElementwiseOperation c_element_op{};
     AccElementwiseOperation acc_element_op{alpha};
-    template<typename AGridDesc_AK0_M_AK1,
-             typename BGridDesc_BK0_N_BK1,
-             typename B1GridDesc_BK0_N_BK1,
-             typename CGridDesc_M_N>
+    template <typename AGridDesc_AK0_M_AK1,
+              typename BGridDesc_BK0_N_BK1,
+              typename B1GridDesc_BK0_N_BK1,
+              typename CGridDesc_M_N>
     struct rt_gridwisegemm
     {
         // GridwiseGemm
         using GridwiseGemm = ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
@@ -286,6 +283,5 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
     };
 };
-
 } // namespace migraphx
 #endif
@@ -31,6 +31,7 @@ from onnx import TensorProto
+
 def onnx_test(op_test):
     def run_test():
         op_info = op_test()
         if len(op_info) > 3:
@@ -1995,16 +1996,17 @@ def gemm_softmax_gemm_test():
     bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT16, [1, 1])
     out = helper.make_tensor_value_info('out', TensorProto.FLOAT16, [1, 1])
-    scale_array = np.array([(1/8)])
+    scale_array = np.array([(1 / 8)])
     scale_tensor = helper.make_tensor(name='scale',
-        data_type=TensorProto.FLOAT16,
-        dims=scale_array.shape,
-        vals=scale_array.flatten().astype(np.float16))
+                                      data_type=TensorProto.FLOAT16,
+                                      dims=scale_array.shape,
+                                      vals=scale_array.flatten().astype(
+                                          np.float16))
     gemm1 = onnx.helper.make_node('MatMul',
-        inputs=['a', 'b'],
-        outputs=['gemm1_out'])
+                                  inputs=['a', 'b'],
+                                  outputs=['gemm1_out'])
     mul1 = onnx.helper.make_node('Mul',
                                  inputs=['gemm1_out', 'scale'],
                                  outputs=['mul1_out'])
@@ -2012,14 +2014,14 @@ def gemm_softmax_gemm_test():
                                  inputs=['mul1_out', 'c'],
                                  outputs=['add1_out'])
     softmax = onnx.helper.make_node('Softmax',
-        inputs=['add1_out'],
-        outputs=['softmax_out'])
+                                    inputs=['add1_out'],
+                                    outputs=['softmax_out'])
     gemm2 = onnx.helper.make_node('MatMul',
-        inputs=['softmax_out', 'b1'],
-        outputs=['out'])
+                                  inputs=['softmax_out', 'b1'],
+                                  outputs=['out'])
-    return ([gemm1, mul1, add1, softmax, gemm2], [a, b, c, b1, bias], [out], [scale_tensor])
+    return ([gemm1, mul1, add1, softmax, gemm2], [a, b, c, b1,
+                                                  bias], [out], [scale_tensor])
 @onnx_test
@@ -36,10 +36,10 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
         migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
         migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
         auto m2_elements = 1 * 12 * 256 * 256;
-        auto a = mm->add_parameter("1", m1_shape);
-        auto b = mm->add_parameter("2", m1_shape);
-        auto b1 = mm->add_parameter("3", m1_shape);
-        auto c = mm->add_parameter("4", m1_shape);
+        auto a           = mm->add_parameter("1", m1_shape);
+        auto b           = mm->add_parameter("2", m1_shape);
+        auto b1          = mm->add_parameter("3", m1_shape);
+        auto c           = mm->add_parameter("4", m1_shape);
         std::vector<float> eights(m2_elements, 0.125);
         auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
         std::vector<float> zeros(m2_elements, 0);
@@ -48,9 +48,9 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
         auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
         b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
-        auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
-        auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
-        auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
+        auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);
+        auto scale   = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
+        auto bias    = mm->add_instruction(migraphx::make_op("add"), scale, zero);
         auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
         mm->add_instruction(migraphx::make_op("dot"), softmax, b1);