gaoqiong / MIGraphX

Commit 7757cfd0 authored May 20, 2022 by turneram

Remove non-inference portions of parse_attention

parent 5a62e9e7
Showing 1 changed file with 27 additions and 84 deletions

src/onnx/parse_attention.cpp  +27 -84
@@ -11,32 +11,18 @@ struct parse_attention : op_parser<parse_attention>
 {
     std::vector<op_desc> operators() const { return {{"Attention"}}; }

-    instruction_ref parse(const op_desc& /*opd*/,
-                          const onnx_parser& parser,
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& /*parser*/,
                           onnx_parser::node_info info,
                           const std::vector<instruction_ref>& args) const
     {
         auto input   = args[0];
         auto weights = args[1];
         auto bias    = args[2];
-        auto mask_index = args[3];
+        // mask_index = args[3];
+        // Raw attention mask is 2d (BxS) and all 1s for BERT-base and BERT-large inference
-        instruction_ref past;
-        instruction_ref extra_add_qk;
-        bool is_past         = false;
-        bool is_extra_add_qk = false;
-        if(args.size() > 4)
-        {
-            past    = args[4];
-            is_past = true;
-        }
-        if(args.size() == 6)
-        {
-            is_extra_add_qk = true;
-            extra_add_qk    = args[5];
-        }
-        // ORT default is 12
+        // BERT-base default is 12, BERT-large default is 16
         std::size_t num_heads = 12;
         if(contains(info.attributes, "num_heads"))
             num_heads = info.attributes.at("num_heads").i();
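As a quick sanity check on the num_heads comment added in this hunk, the head geometry works out as below. This is only an illustrative sketch; the hidden sizes are the well-known BERT-base/BERT-large defaults, not values defined by this commit.

// Sketch: head geometry implied by the num_heads defaults mentioned above.
// Hidden sizes are standard BERT values (an assumption, not from this diff).
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t base_hidden = 768, base_heads = 12;    // BERT-base
    const std::size_t large_hidden = 1024, large_heads = 16; // BERT-large
    std::cout << "BERT-base head_size:  " << base_hidden / base_heads << '\n';   // 64
    std::cout << "BERT-large head_size: " << large_hidden / large_heads << '\n'; // 64
    return 0;
}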
@@ -47,30 +33,14 @@ struct parse_attention : op_parser<parse_attention>
         auto sequence_length   = input_lens.at(1);
         auto input_hidden_size = input_lens.at(2);
-        // bias shape : (3 * hidden_size)
+        // bias shape = (3 * hidden_size)
         auto bias_lens   = bias->get_shape().lens();
         auto hidden_size = bias_lens.at(0) / 3;
         auto head_size   = hidden_size / num_heads;
-        int past_sequence_length = 0;
-        // GetPresent
-        // Input and output shapes:
-        // past : (2, batch_size, num_heads, past_sequence_length, head_size)
-        // present : (2, batch_size, num_heads, past_sequence_length + sequence_length,
-        // head_size)
-        std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
-        if(is_past)
-        {
-            auto past_lens       = past->get_shape().lens();
-            past_sequence_length = past_lens.at(3);
-            present_lens[3] += past_lens[3];
-        }
         // Use GEMM for fully connection.
         auto m = batch_size * sequence_length;
         auto n = bias_lens.front();
         auto k = input_hidden_size;
         // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
         auto bias_type = bias->get_shape().type();
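The shape bookkeeping in this hunk (hidden_size recovered from the fused bias, head_size from num_heads, and the GEMM M/N/K view of the QKV projection) can be checked with a small worked example. The concrete batch, sequence, and hidden sizes below are illustrative assumptions, not values from the diff.

// Sketch: worked shape arithmetic for the fused QKV fully-connected layer.
// Concrete sizes are illustrative (BERT-base-like), not taken from the commit.
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t batch_size = 2, sequence_length = 128, num_heads = 12; // assumed
    const std::size_t bias_len    = 3 * 768;                  // bias shape = (3 * hidden_size)
    const std::size_t hidden_size = bias_len / 3;              // 768
    const std::size_t head_size   = hidden_size / num_heads;   // 64
    const std::size_t m = batch_size * sequence_length;        // GEMM rows:    256
    const std::size_t n = bias_len;                            // GEMM columns: 2304
    const std::size_t k = hidden_size;                         // GEMM depth:   768
    std::cout << "M=" << m << " N=" << n << " K=" << k << " head_size=" << head_size << '\n';
    return 0;
}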
@@ -83,27 +53,21 @@ struct parse_attention : op_parser<parse_attention>
         gemm_1 = info.add_instruction(
             migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
-        /// Use row-major => results(N, M) = 1 * input x weights + 1 x B
-        auto input_sq = info.add_instruction(
+        /// results(N, M) = 1 * input x weights + 1 x B
+        auto input_rs = info.add_instruction(
             migraphx::make_op("reshape",
                               {{"dims", {batch_size * sequence_length, hidden_size}}}),
             input);
-        auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
+        auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_rs, weights);
         auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
-        // LaunchTransQkv
-        // input should be BxSx3xNxH => scratch3: 3xBxNxSxH
+        // LaunchTransQkv: BxSx3xNxH => 3xBxNxSxH
         add_gemms = info.add_instruction(
             migraphx::make_op("reshape",
                               {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
             add_gemms);
         auto transqkv = info.add_instruction(migraphx::make_op("transposeqkv"), add_gemms);
-        // transqkv has shape 3xBxNxSxH
-        // => Q, K, V: each has size BxNxSxH
-        auto batches        = batch_size * num_heads;
-        auto size_per_batch = sequence_length * head_size;
-        auto total_size     = batches * size_per_batch;
+        // Q, K, V: each has size BxNxSxH
         auto q_t = info.add_instruction(
             migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
         auto k_t = info.add_instruction(
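The reshape / transposeqkv / slice sequence in this hunk is purely a layout shuffle. The shape-only sketch below follows the B/S/N/H naming used in the diff's comments; the concrete sizes are illustrative assumptions.

// Sketch: shape flow for the QKV split (B = batch, S = sequence,
// N = num_heads, H = head_size). Sizes are illustrative assumptions.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

static void print_dims(const std::string& name, const std::vector<std::size_t>& dims)
{
    std::cout << name << ": {";
    for(std::size_t i = 0; i < dims.size(); ++i)
        std::cout << dims[i] << (i + 1 < dims.size() ? ", " : "");
    std::cout << "}\n";
}

int main()
{
    const std::size_t B = 2, S = 128, N = 12, H = 64;
    print_dims("gemm output ", {B * S, 3 * N * H}); // input x weights (+ bias)
    print_dims("reshape     ", {B, S, 3, N, H});    // BxSx3xNxH
    print_dims("transposeqkv", {3, B, N, S, H});    // 3xBxNxSxH
    print_dims("q/k/v slice ", {B, N, S, H});       // slice on axis 0, then squeeze
    return 0;
}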
@@ -114,32 +78,12 @@ struct parse_attention : op_parser<parse_attention>
         k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
         v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);
-        if(is_past)
-        {
-            k_t = info.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), past, k_t);
-            v_t = info.add_instruction(
-                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
-        }
-        // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max
-        // sequence length.
-        auto mask_index_lens        = mask_index->get_shape().lens();
-        bool use_raw_attention_mask = mask_index_lens.size() >= 2;
-        // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
-        // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
-        const float rsqrt_head_size   = 1.f / sqrt(static_cast<float>(head_size));
-        const int all_sequence_length = past_sequence_length + sequence_length;
-        const int temp_matrix_size    = sequence_length * all_sequence_length;
-        // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
-        const float alpha = use_raw_attention_mask ? 1.0 : rsqrt_head_size;
+        // compute Q*K' scaled by 1/sqrt(H)
+        // Q: BxNxSxH, K (present_k): BxNxSxH => Q*K': BxNxSxS
+        const float alpha = 1.f / sqrt(static_cast<float>(head_size));
+        // K{B,N,S,H} -> K'{B,N,H,S}
         k_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
         auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
-        if(is_extra_add_qk)
-            gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
         auto alpha_lit = info.add_instruction(
             migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
             info.add_literal(
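The surviving path computes standard scaled dot-product attention. In the usual notation, with H the head size that rsqrt_head_size / alpha refer to:

\[
\mathrm{scores} = \frac{Q K^{\top}}{\sqrt{H}}, \qquad
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{H}}\right) V
\]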
@@ -147,19 +91,18 @@ struct parse_attention : op_parser<parse_attention>
         gemm3 = info.add_instruction(
             migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
-        // apply softmax and store result P to scratch2: BxNxSxS*
+        // Inference mask is all 1s => masking can be skipped
+        // P = softmax result: BxNxSxS
         auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
-        // compute P*V
+        // compute P*V: (BxNxSxS) x (BxNxSxH) => BxNxSxH
         auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
-        // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
+        // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxHiddenSize
         gemm4 = info.add_instruction(migraphx::make_op("transposectx"), gemm4);
-        gemm4 = info.add_instruction(
+        return info.add_instruction(
             make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
-            info.make_contiguous(gemm4));
-        return gemm4;
+            gemm4);
     }
 };
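Putting the hunks together, the simplified parser ends with the usual attention output layout. The small sketch below just traces the final shapes named in the last hunk's comments; sizes are again illustrative assumptions.

// Sketch: output shape of the simplified attention path, following the
// comments in the last hunk. Sizes are illustrative assumptions.
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t B = 2, S = 128, N = 12, H = 64;
    // softmax scores: BxNxSxS -> P*V: BxNxSxH -> transposectx: BxSxNxH
    // -> reshape: BxSx(N*H), matching the input's {batch, sequence, hidden} layout
    std::cout << "output dims: {" << B << ", " << S << ", " << N * H << "}\n"; // {2, 128, 768}
    return 0;
}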