gaoqiong / MIGraphX · Commits
Commit ba7a370a, authored May 02, 2022 by turneram
Formatting
parent eea36256
Showing 6 changed files with 84 additions and 68 deletions (+84 -68)
src/include/migraphx/op/layernorm.hpp (+1 -2)
src/onnx/parse_attention.cpp (+72 -58)
src/onnx/parse_layernorm.cpp (+5 -4)
src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp (+2 -1)
src/targets/gpu/layernorm.cpp (+3 -2)
src/targets/gpu/lowering.cpp (+1 -1)
src/include/migraphx/op/layernorm.hpp

@@ -87,8 +87,7 @@ struct layernorm
        mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
        for(std::size_t i = 0; i < norm_size; ++i)
        {
            output[offset + i] = (data[offset + i] - mean) / mean_square;
            /* if(args.size() == 3)
                output[offset + i] =
                    (data[offset + i] - mean) / mean_square * weights[i] + bias[i];
...
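The hunk above is the core of the reference layer-norm: a mean and a mean-of-squares reduction feed an epsilon-stabilized standard deviation, and each element is normalized by it (the commented-out line is the optional weight/bias epilogue). A minimal, self-contained C++ sketch of that arithmetic, written for illustration only and not taken from the repository:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Illustrative only: normalize one row of `norm_size` elements the way the
    // hunk above does, using the biased variance E[x^2] - E[x]^2 plus epsilon.
    void layernorm_row(const std::vector<float>& data,
                       std::vector<float>& output,
                       std::size_t offset,
                       std::size_t norm_size,
                       float epsilon = 1e-12f)
    {
        float mean        = 0.0f;
        float mean_square = 0.0f;
        for(std::size_t i = 0; i < norm_size; ++i)
        {
            mean += data[offset + i];
            mean_square += data[offset + i] * data[offset + i];
        }
        mean /= norm_size;
        mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon);
        for(std::size_t i = 0; i < norm_size; ++i)
            output[offset + i] = (data[offset + i] - mean) / mean_square;
    }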
src/onnx/parse_attention.cpp
@@ -25,12 +25,12 @@ struct parse_attention : op_parser<parse_attention>
        instruction_ref extra_add_qk;
        bool is_past         = false;
        bool is_extra_add_qk = false;
        if(args.size() > 4)
        {
            past    = args[4];
            is_past = true;
        }
        if(args.size() == 6)
        {
            is_extra_add_qk = true;
            extra_add_qk    = args[5];
...
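For context, the ONNX contrib Attention node takes up to six inputs (input, weights, bias, mask_index, past, extra_add_qk), and the parser infers the optional ones purely from the argument count. A hedged sketch of that detection, with hypothetical names rather than the parser's:

    #include <cstddef>

    // Illustrative only: inputs are {input, weights, bias, mask_index[, past[, extra_add_qk]]}.
    struct attention_optional_inputs
    {
        bool has_past         = false;
        bool has_extra_add_qk = false;
    };

    attention_optional_inputs detect_optional_inputs(std::size_t num_args)
    {
        attention_optional_inputs r;
        r.has_past         = num_args > 4;  // 5th argument present
        r.has_extra_add_qk = num_args == 6; // 6th argument present
        return r;
    }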
@@ -53,14 +53,14 @@ struct parse_attention : op_parser<parse_attention>
        auto head_size = hidden_size / num_heads;
        int past_sequence_length = 0;
        // GetPresent
        // Input and output shapes:
        // past : (2, batch_size, num_heads, past_sequence_length, head_size)
        // present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
        std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
        if(is_past)
        {
            auto past_lens = past->get_shape().lens();
            past_sequence_length = past_lens.at(3);
...
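As the comment spells out, the present cache extends the past one by the current sequence length. A small hedged sketch of that shape bookkeeping (standalone helper, not the parser's code):

    #include <cstddef>
    #include <vector>

    // Illustrative only:
    //   past    : (2, batch_size, num_heads, past_sequence_length, head_size)
    //   present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
    std::vector<std::size_t> present_dims(std::size_t batch_size,
                                          std::size_t num_heads,
                                          std::size_t sequence_length,
                                          std::size_t head_size,
                                          std::size_t past_sequence_length = 0)
    {
        std::size_t all_sequence_length = past_sequence_length + sequence_length;
        return {2, batch_size, num_heads, all_sequence_length, head_size};
    }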
@@ -76,24 +76,34 @@ struct parse_attention : op_parser<parse_attention>
        auto bias_type = bias->get_shape().type();
        std::vector<float> ones_vec(m, 1);
        std::vector<std::size_t> ones_lens{1, m};
        auto ones = info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
        bias = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
        auto gemm_1 = info.add_instruction(
            migraphx::make_op("dot"),
            bias,
            ones /* info.make_contiguous(mb_bias), info.make_contiguous(ones) */);
        /// ORT: Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x B.
        /// Assume row-major => results(N, M) = 1 * input x weights + 1 x B ?
        gemm_1 = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
        auto input_sq = info.add_instruction(
            migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}), input);
        auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
        auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
        // LaunchAttentionKernel:
        // LaunchTransQkv
        // input should be BxSx3xNxH => scratch3: 3xBxNxSxH
        add_gemms = info.add_instruction(
            migraphx::make_op("reshape",
                              {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
            add_gemms);
        std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
        auto transqkv = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", qkv_perm}}), add_gemms);
        // now scratch3 has Q, K, V: each has size BxNxSxH
        // => transqkv has shape 3xBxNxSxH
...
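The reshape-plus-transpose above turns the fused QKV projection from (B, S, 3, N, H) into (3, B, N, S, H), so Q, K and V can later be sliced off along the leading axis. A hedged sketch of how the permutation {2, 0, 3, 1, 4} reorders dimension lengths (helper name is made up):

    #include <cstddef>
    #include <vector>

    // Illustrative only: out[i] = dims[perm[i]], so perm {2, 0, 3, 1, 4}
    // maps {B, S, 3, N, H} to {3, B, N, S, H}.
    std::vector<std::size_t> permute_dims(const std::vector<std::size_t>& dims,
                                          const std::vector<std::size_t>& perm)
    {
        std::vector<std::size_t> out(perm.size());
        for(std::size_t i = 0; i < perm.size(); ++i)
            out[i] = dims[perm[i]];
        return out;
    }

    // permute_dims({batch_size, sequence_length, 3, num_heads, head_size}, {2, 0, 3, 1, 4})
    // == {3, batch_size, num_heads, sequence_length, head_size}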
@@ -101,20 +111,25 @@ struct parse_attention : op_parser<parse_attention>
        auto size_per_batch = sequence_length * head_size;
        auto total_size = batches * size_per_batch;
        auto q_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
        auto k_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
        auto v_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
        q_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), q_t);
        k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
        v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);
        if(is_past)
        {
            k_t = info.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), past, k_t);
            v_t = info.add_instruction(
                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
        }
        // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max
        // sequence length.
        auto mask_index_lens = mask_index->get_shape().lens();
        bool use_raw_attention_mask = mask_index_lens.size() >= 2;
...
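Each of Q, K and V comes out of the transposed tensor via a length-1 slice on axis 0 followed by a squeeze, leaving three (B, N, S, H) tensors. A hedged shape-level sketch of that slice-then-squeeze step (hypothetical helper, not MIGraphX's API):

    #include <cstddef>
    #include <vector>

    // Illustrative only: slice {axes {0}, starts {i}, ends {i + 1}} makes axis 0
    // length 1, and squeeze {axes {0}} removes it, so (3, B, N, S, H) -> (B, N, S, H).
    std::vector<std::size_t> slice_and_squeeze_axis0(const std::vector<std::size_t>& dims)
    {
        return {dims.begin() + 1, dims.end()};
    }

    // slice_and_squeeze_axis0({3, batch_size, num_heads, sequence_length, head_size})
    // == {batch_size, num_heads, sequence_length, head_size}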
@@ -127,27 +142,23 @@ struct parse_attention : op_parser<parse_attention>
        // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
        const float alpha = use_raw_attention_mask ? 1.0 : rsqrt_head_size;
        // K{B,N,S,H} -> K'{B,N,H,S}
        k_t = info.add_instruction(
            make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
        auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
        if(is_extra_add_qk)
            gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
        auto alpha_lit = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
            info.add_literal(
                migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
        gemm3 = info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
        // apply softmax and store result P to scratch2: BxNxSxS*
        std::vector<float> mask(batch_size * num_heads * sequence_length * all_sequence_length, 0);
        if(false and mask_index_lens.size() >= 2)
        {
        }
        else if(false and mask_index_lens.size() == 1)
        {
        }
        // else => no mask
        auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
...
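The block above is the score half of scaled dot-product attention: K is transposed to (B, N, H, S), Q x K' produces (B, N, S, S*) scores, an optional extra_add_qk term is added, the scores are scaled by alpha (1/sqrt(head_size) unless a raw mask defers the scale to the softmax), and a softmax is taken over the last axis. A minimal scalar sketch of that formula for a single query row, illustrative only and independent of MIGraphX:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Illustrative only: attention weights for one query against S* keys,
    // i.e. softmax(alpha * dot(q, k_j)) over j, alpha typically 1/sqrt(head_size).
    std::vector<float> attention_weights(const std::vector<float>& q,
                                         const std::vector<std::vector<float>>& keys,
                                         float alpha)
    {
        std::vector<float> scores(keys.size());
        for(std::size_t j = 0; j < keys.size(); ++j)
        {
            float dot = 0.0f;
            for(std::size_t h = 0; h < q.size(); ++h)
                dot += q[h] * keys[j][h];
            scores[j] = alpha * dot;
        }
        if(scores.empty())
            return scores;
        // numerically stable softmax over the scaled scores
        float max_score = *std::max_element(scores.begin(), scores.end());
        float sum       = 0.0f;
        for(auto& s : scores)
        {
            s = std::exp(s - max_score);
            sum += s;
        }
        for(auto& s : scores)
            s /= sum;
        return scores;
    }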
@@ -156,8 +167,11 @@ struct parse_attention : op_parser<parse_attention>
        auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
        // scratch3 is BxNxSxH, transpose to output BxSxNxH
        gemm4 = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
        gemm4 = info.add_instruction(
            make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
            info.make_contiguous(gemm4));
        return gemm4;
    }
};
...
src/onnx/parse_layernorm.cpp

@@ -27,9 +27,10 @@ struct parse_layernorm : op_parser<parse_layernorm>
        epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
    }
    auto layernorm = info.add_instruction(
        make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
    if(args.size() == 3)
    {
        layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1));
        layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2));
...
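When the node carries scale and bias tensors (args 1 and 2), the parser applies them elementwise after normalization, i.e. y = layernorm(x) * scale + bias. A hedged standalone sketch of that epilogue, not the parser's own code:

    #include <cstddef>
    #include <vector>

    // Illustrative only: elementwise scale-and-shift applied after layernorm
    // when the ONNX node provides weight (scale) and bias tensors.
    void apply_scale_and_bias(std::vector<float>& y,
                              const std::vector<float>& scale,
                              const std::vector<float>& bias)
    {
        for(std::size_t i = 0; i < y.size(); ++i)
            y[i] = y[i] * scale[i] + bias[i];
    }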
src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp

@@ -12,7 +12,8 @@ namespace device {
void layernorm(hipStream_t stream, const argument& result, const argument& arg1);
// void layernorm(hipStream_t stream, const argument& result, const argument& arg1, const argument&
// arg2, const argument& arg3, const int64_t axis);
void triadd_layernorm(hipStream_t stream,
                      const argument& result,
...
src/targets/gpu/layernorm.cpp

@@ -19,7 +19,8 @@ argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<ar
    {
        auto n_dim = args.front().get_shape().lens().size();
        auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
        device::layernorm(ctx.get_stream().get(), args.back(), args[0], args[1], args[2],
                          tuned_axis);
    }
    else */
    std::cout << "calling device::ln" << std::endl;
...
src/targets/gpu/lowering.cpp

@@ -389,7 +389,7 @@ struct miopen_apply
        apply_map.emplace(op_name, [=](instruction_ref ins) {
            auto output = insert_allocation(ins, ins->get_shape());
            std::vector<instruction_ref> refs = ins->inputs();
            if(op_name == "layernorm")
            {
                std::cout << "layernorm op" << std::endl;
            }
...
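The lowering pass keeps a map from operator name to a callback that rewrites matching instructions for the GPU target; each callback allocates an output buffer and works from the instruction's inputs. A hedged, heavily simplified sketch of that registration pattern with stand-in types (not MIGraphX's actual classes):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Illustrative only: a registry in the spirit of
    // apply_map.emplace(op_name, [=](instruction_ref ins) { ... });
    struct lowering_registry
    {
        std::unordered_map<std::string, std::function<void(const std::string&)>> apply_map;

        void register_op(const std::string& op_name)
        {
            apply_map.emplace(op_name, [op_name](const std::string& ins_name) {
                // A real handler would allocate the output and rebuild the
                // instruction against the GPU op; here we only trace the call.
                std::cout << "lowering " << op_name << " (" << ins_name << ")" << std::endl;
            });
        }
    };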