gaoqiong / MIGraphX · Commits

Commit fe9a42f1 · authored May 17, 2022 by turneram · parent 96663815

Fix transpose kernels

Showing 9 changed files with 162 additions and 151 deletions.
src/include/migraphx/op/transposectx.hpp                            +20  -11
src/include/migraphx/op/transposeqkv.hpp                            +23  -12
src/onnx/parse_attention.cpp                                         +9  -13
src/targets/gpu/jit/bert_transpose.cpp                               +4  -32
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp   +15  -41
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp   +18  -39
test/ref_ops_test.cpp                                               +70   -0
test/verify/0transposectx_test.cpp                                   +2   -2
test/verify/0transposeqkv_test.cpp                                   +1   -1
src/include/migraphx/op/transposectx.hpp

@@ -20,15 +20,6 @@ namespace op {
 struct transposectx
 {
-    int head_size    = 64;
-    bool reversed_bs = false;
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.head_size, "head_size"), f(self.reversed_bs, "reversed_bs"));
-    }
-
     std::string name() const { return "transposectx"; }
     shape compute_shape(std::vector<shape> inputs) const

@@ -45,10 +36,28 @@ struct transposectx
         // Input:  BxNxSxH
         // Output: BxSxNxH
         argument result{output_shape};
+        auto in_s = args.front().get_shape();
+        auto lens = in_s.lens();
         visit_all(result, args.front())([&](auto output, const auto input) {
             par_for(output_shape.elements(), [&](auto i) {
-                // TODO: calculate in_offet and out_offset
-                output[i] = input[i];
+                auto idx            = in_s.multi(i);
+                int n               = idx.at(1);
+                int s               = idx.at(2);
+                int b               = idx.front();
+                int num_heads       = lens.at(1);
+                int sequence_length = lens.at(2);
+                int head_size       = lens.back();
+                const int NH        = num_heads * head_size;
+                const int NHS       = NH * sequence_length;
+                // const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
+                const int out_offset   = n * head_size + s * NH + b * NHS;
+                const int j            = idx.back();
+                output[out_offset + j] = input[i];
             });
         });
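The reference implementation above replaces the placeholder copy with explicit offset arithmetic. A minimal standalone check of that arithmetic against a naive four-loop BxNxSxH -> BxSxNxH transpose (illustrative only, not part of the commit; the dimensions are arbitrary test values):

#include <cassert>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    const int B = 2, N = 3, S = 4, H = 5;
    std::vector<int> input(B * N * S * H);
    std::iota(input.begin(), input.end(), 0);

    // Naive reference: out[b][s][n][h] = in[b][n][s][h]
    std::vector<int> expected(input.size());
    for(int b = 0; b < B; b++)
        for(int n = 0; n < N; n++)
            for(int s = 0; s < S; s++)
                for(int h = 0; h < H; h++)
                    expected[((b * S + s) * N + n) * H + h] =
                        input[((b * N + n) * S + s) * H + h];

    // Offset arithmetic from the new compute(): decompose the flat input
    // index into (b, n, s, h) and scatter to n*H + s*NH + b*NHS + h.
    std::vector<int> output(input.size());
    const int NH  = N * H;
    const int NHS = NH * S;
    for(int i = 0; i < (int)input.size(); i++)
    {
        int h = i % H;
        int s = (i / H) % S;
        int n = (i / (H * S)) % N;
        int b = i / (H * S * N);
        output[n * H + s * NH + b * NHS + h] = input[i];
    }
    assert(output == expected);
    std::printf("transposectx offset math OK\n");
}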
src/include/migraphx/op/transposeqkv.hpp

@@ -20,15 +20,6 @@ namespace op {
 struct transposeqkv
 {
-    int head_size    = 64;
-    bool reversed_bs = false;
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.head_size, "head_size"), f(self.reversed_bs, "reversed_bs"));
-    }
-
     std::string name() const { return "transposeqkv"; }
     shape compute_shape(std::vector<shape> inputs) const

@@ -42,14 +33,34 @@ struct transposeqkv
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
-        // Input:  BxSxKxNxH or SxBxKxNxH
+        // Input:  BxSxKxNxH
         // Output: KxBxNxSxH
         // K is the number of identical matrix
+        auto in_s = args.front().get_shape();
+        auto lens = in_s.lens();
         argument result{output_shape};
         visit_all(result, args.front())([&](auto output, const auto input) {
             par_for(output_shape.elements(), [&](auto i) {
-                // TODO: calculate in_offet and out_offset
-                output[i] = input[i];
+                auto idx                  = in_s.multi(i);
+                const int b               = idx.front();
+                const int s               = idx.at(1);
+                const int m               = idx.at(2);
+                const int n               = idx.at(3);
+                const int j               = idx.back();
+                const int num_heads       = lens[3];
+                const int sequence_length = lens[1];
+                const int batch_size      = lens[0];
+                const int H               = lens.back();
+                const int NH              = num_heads * H;
+                const int NHS             = NH * sequence_length;
+                const int out_offset =
+                    s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
+                output[out_offset + j] = input[i];
             });
         });
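The new ref test later in this commit asserts that transposeqkv is equivalent to a plain transpose with permutation {2, 0, 3, 1, 4}, i.e. BxSxKxNxH -> KxBxNxSxH. A standalone sketch (not part of the commit; arbitrary dimensions) verifying that the out_offset arithmetic above implements exactly that permutation:

#include <cassert>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    const int B = 2, S = 3, K = 3, N = 2, H = 4;
    std::vector<int> input(B * S * K * N * H);
    std::iota(input.begin(), input.end(), 0);

    // Naive reference: out[m][b][n][s][h] = in[b][s][m][n][h]
    std::vector<int> expected(input.size());
    for(int b = 0; b < B; b++)
        for(int s = 0; s < S; s++)
            for(int m = 0; m < K; m++)
                for(int n = 0; n < N; n++)
                    for(int h = 0; h < H; h++)
                        expected[(((m * B + b) * N + n) * S + s) * H + h] =
                            input[(((b * S + s) * K + m) * N + n) * H + h];

    // Offset arithmetic from the new compute().
    std::vector<int> output(input.size());
    const int NH  = N * H;
    const int NHS = NH * S;
    for(int i = 0; i < (int)input.size(); i++)
    {
        int h = i % H;
        int n = (i / H) % N;
        int m = (i / (H * N)) % K;
        int s = (i / (H * N * K)) % S;
        int b = i / (H * N * K * S);
        output[s * H + n * S * H + b * NHS + m * NHS * B + h] = input[i];
    }
    assert(output == expected);
    std::printf("transposeqkv offset math OK\n");
}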
src/onnx/parse_attention.cpp

@@ -82,31 +82,28 @@ struct parse_attention : op_parser<parse_attention>
-        auto gemm_1 = info.add_instruction(
-            migraphx::make_op("dot"),
-            bias,
-            ones /* info.make_contiguous(mb_bias), info.make_contiguous(ones) */);
+        auto gemm_1 = info.add_instruction(migraphx::make_op("dot"), bias, ones);
         gemm_1 = info.add_instruction(
             migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
-        /// ORT: Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x
-        /// B. Assume row-major => results(N, M) = 1 * input x weights + 1 x B ?
+        /// Use row-major => results(N, M) = 1 * input x weights + 1 x B
         auto input_sq = info.add_instruction(
             migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
             input);
         auto gemm_2    = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
         auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);

         // LaunchAttentionKernel:
         //   LaunchTransQkv
         //   input should be BxSx3xNxH => scratch3: 3xBxNxSxH
         add_gemms = info.add_instruction(
             migraphx::make_op(
                 "reshape", {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
             add_gemms);
         std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
-        auto transqkv = info.add_instruction(
-            migraphx::make_op("transposeqkv", {{"head_size", head_size}}), add_gemms);
-        // now scratch3 has Q, K, V: each has size BxNxSxH
-        // => transqkv has shape 3xBxNxSxH
+        auto transqkv = info.add_instruction(migraphx::make_op("transposeqkv"), add_gemms);
+        // transqkv has shape 3xBxNxSxH
+        // => Q, K, V: each has size BxNxSxH
         auto batches        = batch_size * num_heads;
         auto size_per_batch = sequence_length * head_size;
         auto total_size     = batches * size_per_batch;

@@ -158,12 +155,11 @@ struct parse_attention : op_parser<parse_attention>
         // Inference mask is all 1s => masking can be skipped
         auto softmax =
             info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
-        // compute P*V (as V*P), and store in scratch3: BxNxSxH
+        // compute P*V
         auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
-        // scratch3 is BxNxSxH, transpose to output BxSxNxH
-        gemm4 = info.add_instruction(
-            migraphx::make_op("transposectx", {{"head_size", head_size}}), gemm4);
+        // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
+        gemm4 = info.add_instruction(migraphx::make_op("transposectx"), gemm4);
         gemm4 = info.add_instruction(
             make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
             info.make_contiguous(gemm4));
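The comments in this parser trace the tensor shapes through the attention subgraph. A short sketch that prints that shape trace for illustrative sizes (the concrete batch_size, sequence_length, num_heads, head_size values here are hypothetical; the stage shapes follow the diff's comments):

#include <cstdio>

int main()
{
    const long B = 1, S = 128, N = 12, Hd = 64;
    const long hidden = N * Hd; // hidden_size = num_heads * head_size

    std::printf("input             : %ld x %ld x %ld\n", B, S, hidden);
    std::printf("qkv gemm + bias   : %ld x %ld\n", B * S, 3 * hidden);
    std::printf("reshape           : %ld x %ld x 3 x %ld x %ld\n", B, S, N, Hd);
    std::printf("transposeqkv      : 3 x %ld x %ld x %ld x %ld\n", B, N, S, Hd);
    std::printf("Q*K^T scores      : %ld x %ld x %ld x %ld\n", B, N, S, S);
    std::printf("softmax (axis 3)  : %ld x %ld x %ld x %ld\n", B, N, S, S);
    std::printf("P*V               : %ld x %ld x %ld x %ld\n", B, N, S, Hd);
    std::printf("transposectx      : %ld x %ld x %ld x %ld\n", B, S, N, Hd);
    std::printf("reshape (output)  : %ld x %ld x %ld\n", B, S, N * Hd);
}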
src/targets/gpu/jit/bert_transpose.cpp

@@ -19,17 +19,13 @@ namespace gpu {
 static const char* const transposectx_kernel = R"__migraphx__(
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/transposectx.hpp>
 #include <migraphx/kernels/basic_ops.hpp>
-#include <migraphx/kernels/integral_constant.hpp>
-#include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>

 namespace migraphx {
 extern "C" {
 __global__ void transposectx_kernel(void* input_p, void* output_p)
 {
     make_tensors()(input_p, output_p)([](auto input, auto output) {
-        auto settings = make_transposectx_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{HEAD_SIZE}),
-                                                   MIGRAPHX_MAKE_CONSTANT(bool{REVERSED_BS}));
-        transposectx(input, output, settings);
+        transposectx(input, output);
     });
 }

@@ -44,21 +40,11 @@ struct transposectx_compiler : compiler<transposectx_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 64);
+        options.set_launch_params(
+            v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
         options.output      = inputs.back();
         options.inputs      = inputs;
         options.kernel_name = "transposectx_kernel";
-
-        // head_size
-        assert(v.contains("head_size"));
-        auto head_size = v.at("head_size").to<int>();
-        options.params += " -DHEAD_SIZE=" + std::to_string(head_size);
-
-        // reversed_bs
-        assert(v.contains("reversed_bs"));
-        auto reversed_bs = v.at("reversed_bs").to<bool>();
-        options.params += " -DREVERSED_BS=" + std::to_string(reversed_bs);
-
         return compile_hip_code_object(transposectx_kernel, options);
     }

@@ -71,17 +57,13 @@ struct transposectx_compiler : compiler<transposectx_compiler>
 static const char* const transposeqkv_kernel = R"__migraphx__(
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/transposeqkv.hpp>
 #include <migraphx/kernels/basic_ops.hpp>
-#include <migraphx/kernels/integral_constant.hpp>
-#include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>

 namespace migraphx {
 extern "C" {
 __global__ void transposeqkv_kernel(void* input_p, void* output_p)
 {
     make_tensors()(input_p, output_p)([](auto input, auto output) {
-        auto settings = make_transposeqkv_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{HEAD_SIZE}),
-                                                   MIGRAPHX_MAKE_CONSTANT(bool{REVERSED_BS}));
-        transposeqkv(input, output, settings);
+        transposeqkv(input, output);
     });
 }

@@ -96,21 +78,11 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 64);
+        options.set_launch_params(
+            v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
         options.output      = inputs.back();
         options.inputs      = inputs;
         options.kernel_name = "transposeqkv_kernel";
-
-        // head_size
-        assert(v.contains("head_size"));
-        auto head_size = v.at("head_size").to<int>();
-        options.params += " -DHEAD_SIZE=" + std::to_string(head_size);
-
-        // reversed_bs
-        assert(v.contains("reversed_bs"));
-        auto reversed_bs = v.at("reversed_bs").to<bool>();
-        options.params += " -DREVERSED_BS=" + std::to_string(reversed_bs);
-
         return compile_hip_code_object(transposeqkv_kernel, options);
     }
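Both compilers now derive the block size from the input shape instead of passing head_size and reversed_bs through -D compile definitions: the local size is inputs.front().lens().back() (the trailing head dimension) and the global size covers one thread per output element. A worked example of that arithmetic for an assumed input shape {1, 12, 128, 64}; the rounding here is a simplified model of what compute_global_for does, not its actual implementation:

#include <cstdio>
#include <vector>

int main()
{
    std::vector<long> lens{1, 12, 128, 64}; // BxNxSxH (assumed example shape)
    long elements = 1;
    for(auto d : lens)
        elements *= d;

    long local = lens.back(); // inputs.front().lens().back() == head size
    // Round the element count up to a multiple of the block size.
    long global = (elements + local - 1) / local * local;

    std::printf("elements = %ld, local = %ld, global = %ld\n", elements, local, global);
    // elements = 98304, local = 64, global = 98304
}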
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp

@@ -7,55 +7,29 @@
 namespace migraphx {

-template <class T, class U>
-struct transposectx_settings
-{
-    T head_size{};
-    U reversed_bs{};
-};
-
-template <class... Ts>
-constexpr transposectx_settings<Ts...> make_transposectx_settings(Ts... xs)
-{
-    return {xs...};
-}
-
-template <class T, class U, class Settings>
-__device__ void transposectx(const T& input_t, const U& output_t, Settings st)
+template <class T, class U>
+__device__ void transposectx(const T& input_t, const U& output_t)
 {
     // Input:  BxNxSxH
     // Output: BxSxNxH
+    auto index       = make_index();
+    auto input_shape = input_t.get_shape();
+    auto lens        = input_shape.lens;
-    auto head_size   = st.head_size;
-    auto reversed_bs = st.reversed_bs;
-    int n = threadIdx.y;
-    int s = blockIdx.x;
-    int b = blockIdx.y;
-    int num_heads       = blockDim.y;
-    int sequence_length = gridDim.x;
+    const int num_heads       = lens[1];
+    const int sequence_length = lens[2];
+    int head_size             = lens[3];
+    auto idx = input_shape.multi(index.global);
+    int n    = idx[1];
+    int s    = idx[2];
+    int b    = idx[0];
     const int NH  = num_heads * head_size;
     const int NHS = NH * sequence_length;
-    const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
-    int out_offset      = 0;
-    if(reversed_bs)
-    {
-        const int batch_size = gridDim.y;
-        const int BNH        = NH * batch_size;
-        out_offset           = n * head_size + b * NH + s * BNH;
-    }
-    else
-    {
-        out_offset = n * head_size + s * NH + b * NHS;
-    }
-    const int i = threadIdx.x;
-    if(i < head_size)
-    {
-        output_t[out_offset + i] = input_t[in_offset + i];
-    }
+    const int out_offset = n * head_size + s * NH + b * NHS;
+    if(index.local < 1024)
+        output_t[out_offset + idx[3]] = input_t[index.global];
 }
 } // namespace migraphx
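The rewritten kernel drops the coordinates derived from the launch geometry (threadIdx/blockIdx/blockDim/gridDim) and instead decomposes the flat global id via input_shape.multi(index.global), so the mapping depends only on the tensor shape. A sketch of that decomposition for a row-major 4-D shape (a simplified stand-in for MIGraphX's shape type, not its actual implementation):

#include <array>
#include <cassert>

// Turn a flat row-major index into per-dimension coordinates,
// like shape::multi(i): peel dimensions from the innermost outward.
template <std::size_t N>
std::array<int, N> multi(const std::array<int, N>& lens, int flat)
{
    std::array<int, N> idx{};
    for(std::size_t d = N; d-- > 0;)
    {
        idx[d] = flat % lens[d];
        flat /= lens[d];
    }
    return idx;
}

int main()
{
    // BxNxSxH, as in transposectx
    std::array<int, 4> lens{2, 3, 4, 5};
    auto idx = multi(lens, 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4);
    assert(idx[0] == 1 && idx[1] == 2 && idx[2] == 3 && idx[3] == 4);
}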
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp

@@ -7,57 +7,36 @@
 namespace migraphx {

-template <class T, class U>
-struct transposeqkv_settings
-{
-    T head_size{};
-    U reversed_bs{};
-};
-
-template <class... Ts>
-constexpr transposeqkv_settings<Ts...> make_transposeqkv_settings(Ts... xs)
-{
-    return {xs...};
-}
-
-template <class T, class U, class Settings>
-__device__ void transposeqkv(const T& input_t, const U& output_t, Settings st)
+template <class T, class U>
+__device__ void transposeqkv(const T& input_t, const U& output_t)
 {
     // Input:  BxSxKxNxH or SxBxKxNxH
     // Output: KxBxNxSxH
     // K is the number of identical matrix
-    auto H           = st.head_size;
-    auto reversed_bs = st.reversed_bs;
-    int n = threadIdx.y;
-    int s = blockIdx.x;
-    int b = blockIdx.y;
-    int m = blockIdx.z; // matrix id
-    const int num_heads       = blockDim.y;
-    const int sequence_length = gridDim.x;
-    const int batch_size      = gridDim.y;
-    const int chunk_num       = gridDim.z;
+    auto index       = make_index();
+    auto input_shape = input_t.get_shape();
+    auto lens        = input_shape.lens;
+    auto idx         = input_shape.multi(index.global);
+    const int b = idx[0];
+    const int s = idx[1];
+    const int m = idx[2];
+    const int n = idx[3];
+    // const int j = idx[4];
+    const int num_heads       = lens[3];
+    const int sequence_length = lens[1];
+    const int batch_size      = lens[0];
+    const int H               = lens[4];
     const int NH  = num_heads * H;
     const int NHS = NH * sequence_length;
-    int in_offset = 0;
-    if(reversed_bs)
-    {
-        const int BNH = NH * batch_size;
-        in_offset     = n * H + (m + b * chunk_num) * NH + s * BNH * chunk_num;
-    }
-    else
-    {
-        in_offset = n * H + (m + s * chunk_num) * NH + b * NHS * chunk_num;
-    }
     const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
-    const int i = threadIdx.x;
-    if(i < H)
+    if(index.global < input_shape.elements())
     {
-        output_t[out_offset + i] = input_t[in_offset + i];
+        output_t[out_offset + idx[4]] = input_t[index.global];
     }
 }
test/ref_ops_test.cpp

@@ -666,6 +666,76 @@ TEST_CASE(batch_norm_inference_test)
     EXPECT(migraphx::verify_range(result_vector, gold));
 }

+TEST_CASE(bert_transpose_ops_test)
+{
+    {
+        // transposeQKV
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        const int k = 3, b = 1, n = 2, s = 2, h = 1;
+        migraphx::shape sh{migraphx::shape::float_type, {b, s, k, n, h}};
+        std::vector<float> data(b * s * k * n * h);
+        std::iota(data.begin(), data.end(), 0);
+        auto l1 = mm->add_literal(migraphx::literal{sh, data});
+        mm->add_instruction(migraphx::make_op("transposeqkv"), l1);
+        p.compile(migraphx::ref::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> result_vector(k * b * n * s * h);
+        result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0, 1, 4, 5, 8, 9, 2, 3, 6, 7, 10, 11};
+
+        migraphx::program p2;
+        auto* mm2 = p2.get_main_module();
+        auto l2   = mm2->add_literal(migraphx::literal{sh, data});
+        mm2->add_instruction(
+            migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), l2);
+        p2.compile(migraphx::ref::target{});
+        auto result2 = p2.eval({}).back();
+        std::vector<float> result_vector2(k * b * n * s * h);
+        result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); });
+        for(auto& i : result_vector2)
+            std::cout << i << ", ";
+        std::cout << std::endl;
+        EXPECT(migraphx::verify_range(result_vector, result_vector2));
+    }
+    {
+        // transposeCtx
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape s{migraphx::shape::float_type, {2, 2, 2, 2}};
+        std::vector<float> data(2 * 2 * 2 * 2);
+        std::iota(data.begin(), data.end(), 0);
+        auto l1 = mm->add_literal(migraphx::literal{s, data});
+        mm->add_instruction(migraphx::make_op("transposectx"), l1);
+        p.compile(migraphx::ref::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> result_vector(2 * 2 * 2 * 2);
+        result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15};
+        EXPECT(migraphx::verify_range(result_vector, gold));
+    }
+    {
+        // transposeCtx
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        const int b = 2, n = 2, s = 3, h = 4;
+        migraphx::shape sh{migraphx::shape::float_type, {b, n, s, h}};
+        std::vector<float> data(b * n * s * h);
+        std::iota(data.begin(), data.end(), 0);
+        auto l1 = mm->add_literal(migraphx::literal{sh, data});
+        mm->add_instruction(migraphx::make_op("transposectx"), l1);
+        p.compile(migraphx::ref::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> result_vector(b * n * s * h);
+        result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0,  1,  2,  3,  12, 13, 14, 15, 4,  5,  6,  7,
+                                16, 17, 18, 19, 8,  9,  10, 11, 20, 21, 22, 23,
+                                24, 25, 26, 27, 36, 37, 38, 39, 28, 29, 30, 31,
+                                40, 41, 42, 43, 32, 33, 34, 35, 44, 45, 46, 47};
+        EXPECT(migraphx::verify_range(result_vector, gold));
+    }
+}
+
 TEST_CASE(broadcast_test)
 {
     migraphx::program p;
test/verify/0transposectx_test.cpp

@@ -11,8 +11,8 @@ struct test_transposectx : verify_program<test_transposectx>
         migraphx::program p;
         auto* mm = p.get_main_module();
         auto x   = mm->add_parameter(
-            "x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
-        mm->add_instruction(migraphx::make_op("transposectx", {{"head_size", 64}}), x);
+            "x", migraphx::shape{migraphx::shape::float_type, {16, 12, 384, 64}});
+        mm->add_instruction(migraphx::make_op("transposectx"), x);
         p.debug_print();
         return p;
     }
test/verify/0transposeqkv_test.cpp

@@ -11,7 +11,7 @@ struct test_transposeqkv : verify_program<test_transposeqkv>
         auto* mm = p.get_main_module();
         auto x   = mm->add_parameter(
             "x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 12, 64}});
-        mm->add_instruction(migraphx::make_op("transposeqkv", {{"head_size", 64}}), x);
+        mm->add_instruction(migraphx::make_op("transposeqkv"), x);
         p.debug_print();
         return p;
     }