gaoqiong/MIGraphX, commit 3ea9fe4c
Authored May 20, 2022 by turneram
Add attention, layernorm op, transposectx, and transposeqkv
Parent: 4a312201
Showing 20 changed files with 874 additions and 0 deletions (+874, -0).
src/CMakeLists.txt                                                    +3   -0
src/include/migraphx/op/layernorm.hpp                                 +124 -0
src/include/migraphx/op/transposectx.hpp                              +71  -0
src/include/migraphx/op/transposeqkv.hpp                              +76  -0
src/include/migraphx/operators.hpp                                    +3   -0
src/onnx/parse_attention.cpp                                          +168 -0
src/onnx/parse_layernorm.cpp                                          +46  -0
src/targets/gpu/CMakeLists.txt                                        +2   -0
src/targets/gpu/include/migraphx/gpu/layernorm.hpp                    +40  -0
src/targets/gpu/jit/bert_transpose.cpp                                +99  -0
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp     +36  -0
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp     +43  -0
src/targets/gpu/layernorm.cpp                                         +24  -0
src/targets/gpu/lowering.cpp                                          +3   -0
test/onnx/gen_onnx.py                                                 +16  -0
test/onnx/verify_onnx.cpp                                             +25  -0
test/ref_ops_test.cpp                                                 +44  -0
test/verify/0layernorm_test.cpp                                       +17  -0
test/verify/0transposectx_test.cpp                                    +17  -0
test/verify/0transposeqkv_test.cpp                                    +17  -0
src/CMakeLists.txt

@@ -117,6 +117,7 @@ register_migraphx_ops(
     if_op
     im2col
     isnan
+    layernorm
     leaky_relu
     less
     load
@@ -184,6 +185,8 @@ register_migraphx_ops(
     tan
     topk
     transpose
+    transposectx
+    transposeqkv
     unary_not
     undefined
     unknown
src/include/migraphx/op/layernorm.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_OPERATORS_LAYERNORMALIZATION_HPP
#define MIGRAPHX_GUARD_OPERATORS_LAYERNORMALIZATION_HPP

#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <cmath>
#include <numeric>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct layernorm
{
    float epsilon = 1e-3;
    int64_t axis  = -1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.epsilon, "epsilon"), f(self.axis, "axis"));
    }

    value attributes() const
    {
        value normalize;
        normalize["axis"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    std::string name() const { return "layernorm"; }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        if(inputs.size() == 2)
        {
            if(inputs.at(1).lens().front() != inputs.front().lens().at(axis))
                MIGRAPHX_THROW("LAYERNORM: weights have wrong shape");
        }
        if(inputs.size() == 3)
        {
            if(inputs.at(2).lens().front() != inputs.front().lens().at(axis))
                MIGRAPHX_THROW("LAYERNORM: bias has wrong shape");
        }
        return inputs.front();
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto x_lens = args.front().get_shape().lens();
        // norm_count: number of independent normalization groups;
        // norm_size: number of elements normalized together in each group
        auto norm_count = std::accumulate(
            x_lens.begin(), x_lens.begin() + axis, std::size_t{1}, std::multiplies<std::size_t>());
        auto norm_size = std::accumulate(
            x_lens.begin() + axis, x_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
        if(args.size() == 3)
        {
            visit_all(result, args[0], args[1], args[2])(
                [&](auto output, auto data, auto weights, auto bias) {
                    par_for(norm_count, [&](auto idx) {
                        auto offset        = idx * norm_size;
                        double mean        = 0;
                        double mean_square = 0;
                        for(std::size_t i = 0; i < norm_size; ++i)
                        {
                            mean += data[offset + i];
                            mean_square += data[offset + i] * data[offset + i];
                        }
                        mean /= norm_size;
                        // mean_square now holds the standard deviation,
                        // with epsilon folded inside the sqrt
                        mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
                        for(std::size_t i = 0; i < norm_size; ++i)
                        {
                            if(args.size() == 3)
                                output[offset + i] =
                                    (data[offset + i] - mean) / mean_square * weights[i] + bias[i];
                            else
                                output[offset + i] =
                                    (data[offset + i] - mean) / mean_square * weights[i];
                        }
                    });
                });
        }
        else
        {
            visit_all(result, args[0])([&](auto output, auto data) {
                par_for(norm_count, [&](auto idx) {
                    auto offset        = idx * norm_size;
                    double mean        = 0;
                    double mean_square = 0;
                    for(std::size_t i = 0; i < norm_size; ++i)
                    {
                        mean += data[offset + i];
                        mean_square += data[offset + i] * data[offset + i];
                    }
                    mean /= norm_size;
                    mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
                    for(std::size_t i = 0; i < norm_size; ++i)
                    {
                        output[offset + i] = (data[offset + i] - mean) / mean_square;
                        // scale and bias handled by pointwise ops
                    }
                });
            });
        }
        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
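For intuition, here is a minimal standalone sketch (not part of the commit, a hypothetical helper) of the per-group arithmetic that layernorm::compute performs. With x = {1, 2, 3} and the epsilon of 1e-5 that the tests later in this commit use, it reproduces their gold values (-1.22474, 0, 1.22474):

#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<double> x{1.0, 2.0, 3.0};
    double mean = 0, mean_square = 0;
    for(auto v : x)
    {
        mean += v;
        mean_square += v * v;
    }
    mean /= x.size();
    // biased variance plus epsilon, exactly as in layernorm::compute
    double stddev = std::sqrt(mean_square / x.size() - mean * mean + 1e-5);
    for(auto v : x)
        std::printf("%f\n", (v - mean) / stddev); // -1.22474, 0, 1.22474
    return 0;
}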
src/include/migraphx/op/transposectx.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSECTX_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSECTX_HPP

#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct transposectx
{
    std::string name() const { return "transposectx"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens = inputs.front().lens();
        std::vector<std::size_t> out_lens{lens[0], lens[2], lens[1], lens[3]};
        return {inputs.front().type(), out_lens};
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        // Input:  BxNxSxH
        // Output: BxSxNxH
        argument result{output_shape};
        auto in_s = args.front().get_shape();
        auto lens = in_s.lens();
        visit_all(result, args.front())([&](auto output, const auto input) {
            par_for(output_shape.elements(), [&](auto i) {
                auto idx             = in_s.multi(i);
                int n                = idx.at(1);
                int s                = idx.at(2);
                int b                = idx.front();
                int num_heads        = lens.at(1);
                int sequence_length  = lens.at(2);
                int head_size        = lens.back();
                const int NH         = num_heads * head_size;
                const int NHS        = NH * sequence_length;
                const int out_offset = n * head_size + s * NH + b * NHS;
                const int j          = idx.back();
                output[out_offset + j] = input[i];
            });
        });
        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
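The offset arithmetic in transposectx::compute is equivalent to a transpose with permutation {0, 2, 1, 3} made contiguous: for input index (b, n, s, h), the computed output offset is the row-major offset of (b, s, n, h) in a BxSxNxH tensor. A throwaway consistency check over small dimensions (a sketch for illustration, not part of the commit or the MIGraphX API):

#include <cassert>

int main()
{
    const int B = 2, N = 3, S = 4, H = 5;
    for(int b = 0; b < B; ++b)
        for(int n = 0; n < N; ++n)
            for(int s = 0; s < S; ++s)
                for(int h = 0; h < H; ++h)
                {
                    // offset formula from transposectx::compute
                    int out_offset = n * H + s * (N * H) + b * (N * H * S);
                    // row-major offset of (b, s, n, h) in a BxSxNxH tensor
                    int expected = ((b * S + s) * N + n) * H + h;
                    assert(out_offset + h == expected);
                }
    return 0;
}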
src/include/migraphx/op/transposeqkv.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSEQKV_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSEQKV_HPP

#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct transposeqkv
{
    std::string name() const { return "transposeqkv"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens = inputs.front().lens();
        std::vector<std::size_t> out_lens{lens[2], lens[0], lens[3], lens[1], lens[4]};
        return {inputs.front().type(), out_lens};
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        // Input:  BxSxKxNxH
        // Output: KxBxNxSxH
        // K is the number of identical matrices
        auto in_s = args.front().get_shape();
        auto lens = in_s.lens();
        argument result{output_shape};
        visit_all(result, args.front())([&](auto output, const auto input) {
            par_for(output_shape.elements(), [&](auto i) {
                auto idx    = in_s.multi(i);
                const int b = idx.front();
                const int s = idx.at(1);
                const int m = idx.at(2);
                const int n = idx.at(3);
                const int j = idx.back();

                const int num_heads       = lens[3];
                const int sequence_length = lens[1];
                const int batch_size      = lens[0];
                const int H               = lens.back();
                const int NH              = num_heads * H;
                const int NHS             = NH * sequence_length;
                const int out_offset =
                    s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
                output[out_offset + j] = input[i];
            });
        });
        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
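Analogously, transposeqkv is equivalent to a transpose with permutation {2, 0, 3, 1, 4} made contiguous, matching the out_lens computed in compute_shape: input element (b, s, m, n, h) lands at the row-major offset of (m, b, n, s, h) in a KxBxNxSxH tensor. A quick check sketch (illustration only, not part of the commit):

#include <cassert>

int main()
{
    const int B = 2, S = 3, K = 3, N = 4, H = 5;
    const int NH = N * H, NHS = NH * S;
    for(int b = 0; b < B; ++b)
        for(int s = 0; s < S; ++s)
            for(int m = 0; m < K; ++m)
                for(int n = 0; n < N; ++n)
                    for(int h = 0; h < H; ++h)
                    {
                        // offset formula from transposeqkv::compute
                        int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
                        // row-major offset of (m, b, n, s, h) in a KxBxNxSxH tensor
                        int expected = ((((m * B + b) * N + n) * S + s) * H) + h;
                        assert(out_offset + h == expected);
                    }
    return 0;
}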
src/include/migraphx/operators.hpp

@@ -43,6 +43,7 @@
 #include <migraphx/op/if_op.hpp>
 #include <migraphx/op/im2col.hpp>
 #include <migraphx/op/isnan.hpp>
+#include <migraphx/op/layernorm.hpp>
 #include <migraphx/op/leaky_relu.hpp>
 #include <migraphx/op/less.hpp>
 #include <migraphx/op/load.hpp>
@@ -108,6 +109,8 @@
 #include <migraphx/op/tan.hpp>
 #include <migraphx/op/topk.hpp>
 #include <migraphx/op/transpose.hpp>
+#include <migraphx/op/transposectx.hpp>
+#include <migraphx/op/transposeqkv.hpp>
 #include <migraphx/op/unary.hpp>
 #include <migraphx/op/unary_not.hpp>
 #include <migraphx/op/undefined.hpp>
src/onnx/parse_attention.cpp (new file, mode 100644)

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_attention : op_parser<parse_attention>
{
    std::vector<op_desc> operators() const { return {{"Attention"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        auto input      = args[0];
        auto weights    = args[1];
        auto bias       = args[2];
        auto mask_index = args[3];
        instruction_ref past;
        instruction_ref extra_add_qk;
        bool is_past         = false;
        bool is_extra_add_qk = false;
        if(args.size() > 4)
        {
            past    = args[4];
            is_past = true;
        }
        if(args.size() == 6)
        {
            is_extra_add_qk = true;
            extra_add_qk    = args[5];
        }

        // ORT default is 12
        std::size_t num_heads = 12;
        if(contains(info.attributes, "num_heads"))
            num_heads = info.attributes.at("num_heads").i();

        // input shape: (batch_size, sequence_length, input_hidden_size)
        auto input_lens        = input->get_shape().lens();
        auto batch_size        = input_lens.at(0);
        auto sequence_length   = input_lens.at(1);
        auto input_hidden_size = input_lens.at(2);

        // bias shape: (3 * hidden_size)
        auto bias_lens   = bias->get_shape().lens();
        auto hidden_size = bias_lens.at(0) / 3;
        auto head_size   = hidden_size / num_heads;

        int past_sequence_length = 0;
        // GetPresent
        // Input and output shapes:
        //   past    : (2, batch_size, num_heads, past_sequence_length, head_size)
        //   present : (2, batch_size, num_heads, past_sequence_length + sequence_length,
        //              head_size)
        std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
        if(is_past)
        {
            auto past_lens       = past->get_shape().lens();
            past_sequence_length = past_lens.at(3);
            present_lens[3] += past_lens[3];
        }

        // Use GEMM for the fully connected layers.
        auto m = batch_size * sequence_length;
        auto n = bias_lens.front();
        auto k = input_hidden_size;
        // Bias shape is (N); broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
        auto bias_type = bias->get_shape().type();
        std::vector<float> ones_vec(m, 1);
        std::vector<std::size_t> ones_lens{1, m};
        auto ones =
            info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
        bias = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
        auto gemm_1 = info.add_instruction(migraphx::make_op("dot"), bias, ones);
        gemm_1      = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);

        /// Use row-major => results(N, M) = 1 * input x weights + 1 x B
        auto input_sq = info.add_instruction(
            migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
            input);
        auto gemm_2    = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
        auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);

        // LaunchTransQkv
        // input should be BxSx3xNxH => scratch3: 3xBxNxSxH
        add_gemms = info.add_instruction(
            migraphx::make_op("reshape",
                              {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
            add_gemms);
        auto transqkv = info.add_instruction(migraphx::make_op("transposeqkv"), add_gemms);

        // transqkv has shape 3xBxNxSxH
        // => Q, K, V: each has size BxNxSxH
        auto batches        = batch_size * num_heads;
        auto size_per_batch = sequence_length * head_size;
        auto total_size     = batches * size_per_batch;
        auto q_t            = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
        auto k_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
        auto v_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
        q_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), q_t);
        k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
        v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);

        if(is_past)
        {
            k_t = info.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), past, k_t);
            v_t = info.add_instruction(
                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
        }

        // Raw attention mask could be 2D (BxS), 3D (BxSxS*), or 4D (Bx1xMxM), where M is the max
        // sequence length.
        auto mask_index_lens        = mask_index->get_shape().lens();
        bool use_raw_attention_mask = mask_index_lens.size() >= 2;

        // compute Q*K' (as K'*Q), scaled by 1/sqrt(H), and store in scratch1: BxNxSxS*
        // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
        const float rsqrt_head_size   = 1.f / sqrt(static_cast<float>(head_size));
        const int all_sequence_length = past_sequence_length + sequence_length;
        const int temp_matrix_size    = sequence_length * all_sequence_length;
        // For a raw attention mask, the 1/sqrt(H) scale is moved into the softmax computation.
        const float alpha = use_raw_attention_mask ? 1.0 : rsqrt_head_size;

        // K{B,N,S,H} -> K'{B,N,H,S}
        k_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
        auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
        if(is_extra_add_qk)
            gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
        auto alpha_lit = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
            info.add_literal(
                migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
        gemm3 =
            info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));

        // apply softmax and store result P in scratch2: BxNxSxS*
        // Inference mask is all 1s => masking can be skipped
        auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);

        // compute P*V
        auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);

        // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
        gemm4 = info.add_instruction(migraphx::make_op("transposectx"), gemm4);
        gemm4 = info.add_instruction(
            make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
            info.make_contiguous(gemm4));
        return gemm4;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
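The bias handling above is worth spelling out: the (3*hidden_size,) bias vector is materialized as a full matrix through the outer product B(N, M) = bias(N, 1) x ones(1, M), then transposed to (M, N) so it can be added elementwise to the (M, N) result of input x weights. A plain-C++ sketch of that broadcast (hypothetical illustration, not the MIGraphX API):

#include <cstdio>

int main()
{
    const int N = 3, M = 2;
    double bias[N] = {0.1, 0.2, 0.3};
    double broadcast[N][M];
    // outer product with a row of ones: every column equals bias
    for(int i = 0; i < N; ++i)
        for(int j = 0; j < M; ++j)
            broadcast[i][j] = bias[i] * 1.0;
    // after the transpose, every row of the (M, N) matrix equals bias
    for(int j = 0; j < M; ++j)
        std::printf("%.1f %.1f %.1f\n", broadcast[0][j], broadcast[1][j], broadcast[2][j]);
    return 0;
}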
src/onnx/parse_layernorm.cpp (new file, mode 100644)

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_layernorm : op_parser<parse_layernorm>
{
    std::vector<op_desc> operators() const { return {{"LayerNormalization"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        float epsilon = 1e-3f;
        int64_t axis  = -1;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "axis"))
        {
            axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
        }
        auto layernorm = info.add_instruction(
            make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
        if(args.size() == 3)
        {
            layernorm = info.add_broadcastable_binary_op("mul", layernorm, args.at(1));
            layernorm = info.add_broadcastable_binary_op("add", layernorm, args.at(2));
        }
        return layernorm;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
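For reference, a hand-built graph equivalent to what this parser emits for a three-input LayerNormalization node might look like the sketch below (an assumption based on the decomposition shown above; add_broadcastable_binary_op appears to insert the broadcast automatically, so the explicit multibroadcast ops here stand in for that helper):

// Sketch: the graph parse_layernorm produces for inputs x, w, b (assumed shapes)
migraphx::program p;
auto* mm = p.get_main_module();
auto x  = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3}});
auto w  = mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3}});
auto b  = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3}});
auto ln = mm->add_instruction(
    migraphx::make_op("layernorm", {{"epsilon", 1e-5f}, {"axis", -1}}), x);
// scale and bias become broadcastable pointwise ops
auto wb = mm->add_instruction(
    migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 3}}}), w);
auto scaled = mm->add_instruction(migraphx::make_op("mul"), ln, wb);
auto bb = mm->add_instruction(
    migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 3}}}), b);
mm->add_instruction(migraphx::make_op("add"), scaled, bb);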
src/targets/gpu/CMakeLists.txt

@@ -148,6 +148,7 @@ add_library(migraphx_gpu
     int8_conv_pack.cpp
     int8_gemm_pack.cpp
     kernel.cpp
+    layernorm.cpp
     lowering.cpp
     logsoftmax.cpp
     loop.cpp
@@ -204,6 +205,7 @@ register_migraphx_gpu_ops(hip_
     floor
     gather
     greater
+    layernorm
     less
     log
     logsoftmax
src/targets/gpu/include/migraphx/gpu/layernorm.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_LAYERNORM_HPP
#define MIGRAPHX_GUARD_RTGLIB_LAYERNORM_HPP

#include <migraphx/op/layernorm.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/argument.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

struct hip_layernorm
{
    op::layernorm op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::layernorm"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    void finalize(context&, const shape&, const std::vector<shape>&);
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
src/targets/gpu/jit/bert_transpose.cpp (new file, mode 100644)

#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

static const char* const transposectx_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/transposectx.hpp>
#include <args.hpp>

namespace migraphx {

extern "C" {
__global__ void transposectx_kernel(void* input_p, void* output_p)
{
    make_tensors()(input_p, output_p)([](auto input, auto output) {
        transposectx(input, output);
    });
}
}

} // namespace migraphx

)__migraphx__";

struct transposectx_compiler : compiler<transposectx_compiler>
{
    std::vector<std::string> names() const { return {"transposectx"}; }

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
        options.output      = inputs.back();
        options.inputs      = inputs;
        options.kernel_name = "transposectx_kernel";
        return compile_hip_code_object(transposectx_kernel, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};

static const char* const transposeqkv_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/transposeqkv.hpp>
#include <args.hpp>

namespace migraphx {

extern "C" {
__global__ void transposeqkv_kernel(void* input_p, void* output_p)
{
    make_tensors()(input_p, output_p)([](auto input, auto output) {
        transposeqkv(input, output);
    });
}
}

} // namespace migraphx

)__migraphx__";

struct transposeqkv_compiler : compiler<transposeqkv_compiler>
{
    std::vector<std::string> names() const { return {"transposeqkv"}; }

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
        options.output      = inputs.back();
        options.inputs      = inputs;
        options.kernel_name = "transposeqkv_kernel";
        return compile_hip_code_object(transposeqkv_kernel, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
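Both compilers follow the same JIT pattern: the HIP kernel source is embedded as a raw string, launch parameters are derived from the output element count, and the struct registers under the same name as the op so lowering can swap the instruction for a compiled code object. A skeleton of that kernel-string half of the pattern, for a hypothetical new op named "myop" (illustration only, not part of this commit):

static const char* const myop_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <args.hpp>

namespace migraphx {

extern "C" {
__global__ void myop_kernel(void* input_p, void* output_p)
{
    make_tensors()(input_p, output_p)([](auto input, auto output) {
        // per-element body for "myop" would go here
    });
}
}

} // namespace migraphx

)__migraphx__";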
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_KERNELS_TRANSPOSECTX_HPP
#define MIGRAPHX_GUARD_KERNELS_TRANSPOSECTX_HPP

#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>

namespace migraphx {

template <class T, class U>
__device__ void transposectx(const T& input_t, const U& output_t)
{
    // Input:  BxNxSxH
    // Output: BxSxNxH
    auto index       = make_index();
    auto input_shape = input_t.get_shape();
    auto lens        = input_shape.lens;

    const int num_heads       = lens[1];
    const int sequence_length = lens[2];
    int head_size             = lens[3];

    auto idx = input_shape.multi(index.global);
    int n    = idx[1];
    int s    = idx[2];
    int b    = idx[0];

    const int NH         = num_heads * head_size;
    const int NHS        = NH * sequence_length;
    const int out_offset = n * head_size + s * NH + b * NHS;

    if(index.global < input_shape.elements())
        output_t[out_offset + idx[3]] = input_t[index.global];
}

} // namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_KERNELS_TRANSPOSEQKV_HPP
#define MIGRAPHX_GUARD_KERNELS_TRANSPOSEQKV_HPP

#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>

namespace migraphx {

template <class T, class U>
__device__ void transposeqkv(const T& input_t, const U& output_t)
{
    // Input:  BxSxKxNxH or SxBxKxNxH
    // Output: KxBxNxSxH
    // K is the number of identical matrices
    auto index       = make_index();
    auto input_shape = input_t.get_shape();
    auto lens        = input_shape.lens;
    auto idx         = input_shape.multi(index.global);

    const int b = idx[0];
    const int s = idx[1];
    const int m = idx[2];
    const int n = idx[3];

    const int num_heads       = lens[3];
    const int sequence_length = lens[1];
    const int batch_size      = lens[0];
    const int H               = lens[4];
    const int NH              = num_heads * H;
    const int NHS             = NH * sequence_length;
    const int out_offset =
        s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;

    if(index.global < input_shape.elements())
    {
        output_t[out_offset + idx[4]] = input_t[index.global];
    }
}

} // namespace migraphx
#endif
src/targets/gpu/layernorm.cpp (new file, mode 100644)

#include <migraphx/gpu/layernorm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/layernorm.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_layernorm::compute_shape(std::vector<shape> inputs) const
{
    // the last input is the preallocated output buffer; drop it before validating
    inputs.pop_back();
    return op.normalize_compute_shape(inputs);
}

argument
hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
    return args.back();
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
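The compute_shape/compute pair above follows the destination-passing convention visible throughout this commit: lowering appends a preallocated output buffer as the last argument, compute_shape pops it before validating, compute writes into args.back() and returns it, and output_alias in the header marks that last argument as the aliased output. A plain-C++ sketch of the idea (hypothetical, not the real MIGraphX API):

#include <vector>

// Hypothetical destination-passing op: the caller owns the output buffer.
std::vector<float>& copy_op(std::vector<std::vector<float>>& args)
{
    auto& input  = args.front();
    auto& output = args.back(); // preallocated output, appended by the caller
    for(std::size_t i = 0; i < input.size(); ++i)
        output[i] = input[i];   // kernel writes in place instead of allocating
    return output;              // aliases args.back(), mirroring output_alias()
}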
src/targets/gpu/lowering.cpp

@@ -11,6 +11,7 @@
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/elu.hpp>
 #include <migraphx/op/if_op.hpp>
+#include <migraphx/op/layernorm.hpp>
 #include <migraphx/op/leaky_relu.hpp>
 #include <migraphx/op/lrn.hpp>
 #include <migraphx/op/pooling.hpp>
@@ -29,6 +30,7 @@
 #include <migraphx/gpu/gemm.hpp>
 #include <migraphx/gpu/greater.hpp>
 #include <migraphx/gpu/int8_conv_pack.hpp>
+#include <migraphx/gpu/layernorm.hpp>
 #include <migraphx/gpu/leaky_relu.hpp>
 #include <migraphx/gpu/less.hpp>
 #include <migraphx/gpu/logical_and.hpp>
@@ -139,6 +141,7 @@ struct miopen_apply
         add_generic_op("exp");
         add_generic_op("floor");
         add_generic_op("greater");
+        add_generic_op("layernorm");
         add_generic_op("less");
         add_generic_op("log");
         add_generic_op("logical_and");
test/onnx/gen_onnx.py

@@ -2622,6 +2622,22 @@ def layernorm_test():
                        bias_add],
                       [x, scale, bias], [y], [pow_tensor, epsilon_tensor])
 
 
+@onnx_test
+def layernorm_op_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2, 3])
+    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [3])
+    b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [3])
+    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2, 3])
+
+    node = onnx.helper.make_node('LayerNormalization',
+                                 inputs=['x', 'w', 'b'],
+                                 outputs=["output"],
+                                 epsilon=1e-5)
+
+    return ([node], [x, w, b], [output])
+
+
 @onnx_test
 def leaky_relu_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...
test/onnx/verify_onnx.cpp
View file @
3ea9fe4c
...
@@ -446,6 +446,31 @@ TEST_CASE(instance_norm_3d_test)
...
@@ -446,6 +446,31 @@ TEST_CASE(instance_norm_3d_test)
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
}
TEST_CASE
(
layernorm_op_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"layernorm_op_test.onnx"
);
p
.
compile
(
migraphx
::
ref
::
target
{});
migraphx
::
shape
sx
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
migraphx
::
shape
swb
{
migraphx
::
shape
::
float_type
,
{
3
}};
std
::
vector
<
float
>
x_vec
{
1.0
,
2.0
,
3.0
,
4.0
,
5.0
,
6.0
};
std
::
vector
<
float
>
w_vec
{
1.0
,
1.0
,
1.0
};
std
::
vector
<
float
>
b_vec
{
0.0
,
0.0
,
0.0
};
migraphx
::
parameter_map
pp
;
pp
[
"x"
]
=
migraphx
::
argument
(
sx
,
x_vec
.
data
());
pp
[
"w"
]
=
migraphx
::
argument
(
swb
,
w_vec
.
data
());
pp
[
"b"
]
=
migraphx
::
argument
(
swb
,
b_vec
.
data
());
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
float
>
result_vector
(
6
);
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
-
1.22474
f
,
0.0
f
,
1.22474
f
,
-
1.22474
f
,
0.0
f
,
1.22474
f
};
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
TEST_CASE
(
lessorequal_test
)
TEST_CASE
(
lessorequal_test
)
{
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"lessorequal_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"lessorequal_test.onnx"
);
...
...
test/ref_ops_test.cpp

@@ -2435,6 +2435,50 @@ TEST_CASE(imagescaler_test)
     EXPECT(migraphx::verify_range(results_vector, gold));
 }
 
+TEST_CASE(layernorm_test)
+{
+    {
+        // with scale and bias
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
+        migraphx::shape swb{migraphx::shape::float_type, {3}};
+        std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+        auto x = mm->add_literal(migraphx::literal{sx, x_vec});
+        auto w = mm->add_literal(migraphx::literal{swb, {1.0, 1.0, 1.0}});
+        auto b = mm->add_literal(migraphx::literal{swb, {0.0, 0.0, 0.0}});
+        mm->add_instruction(migraphx::make_op("layernorm", {{"epsilon", 1e-5}}), x, w, b);
+        p.compile(migraphx::ref::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> results_vector(1 * 2 * 3);
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold = {-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
+        for(auto&& i : results_vector)
+            std::cout << i << ", ";
+        std::cout << std::endl;
+        EXPECT(migraphx::verify_range(results_vector, gold));
+    }
+    {
+        // without scale and bias
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
+        std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+        auto x = mm->add_literal(migraphx::literal{sx, x_vec});
+        mm->add_instruction(migraphx::make_op("layernorm", {{"epsilon", 1e-5}}), x);
+        p.compile(migraphx::ref::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> results_vector(1 * 2 * 3);
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold = {-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
+        for(auto&& i : results_vector)
+            std::cout << i << ", ";
+        std::cout << std::endl;
+        EXPECT(migraphx::verify_range(results_vector, gold));
+    }
+}
+
 TEST_CASE(leaky_relu_test)
 {
     migraphx::program p;
test/verify/0layernorm_test.cpp (new file, mode 100644)

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_layernorm : verify_program<test_layernorm>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x   = mm->add_parameter(
            "x", migraphx::shape{migraphx::shape::float_type, {1, 384, 768}});
        mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}}), x);
        return p;
    }
};
\ No newline at end of file
test/verify/0transposectx_test.cpp (new file, mode 100644)

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_transposectx : verify_program<test_transposectx>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x   = mm->add_parameter(
            "x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
        mm->add_instruction(migraphx::make_op("transposectx"), x);
        return p;
    }
};
test/verify/0transposeqkv_test.cpp (new file, mode 100644)

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_transposeqkv : verify_program<test_transposeqkv>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x   = mm->add_parameter(
            "x", migraphx::shape{migraphx::shape::float_type, {2, 384, 3, 12, 64}});
        mm->add_instruction(migraphx::make_op("transposeqkv"), x);
        return p;
    }
};