Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1974671d
Commit
1974671d
authored
May 17, 2022
by
turneram
Browse files
Formatting
parent
fe9a42f1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
48 additions
and
46 deletions
+48
-46
src/include/migraphx/op/transposectx.hpp
src/include/migraphx/op/transposectx.hpp
+5
-5
src/include/migraphx/op/transposeqkv.hpp
src/include/migraphx/op/transposeqkv.hpp
+7
-6
src/onnx/parse_attention.cpp
src/onnx/parse_attention.cpp
+4
-8
src/targets/gpu/jit/bert_transpose.cpp
src/targets/gpu/jit/bert_transpose.cpp
+4
-2
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
...ets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
+8
-8
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
...ets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
+8
-8
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+12
-9
No files found.
src/include/migraphx/op/transposectx.hpp
View file @
1974671d
...
@@ -48,15 +48,15 @@ struct transposectx
...
@@ -48,15 +48,15 @@ struct transposectx
int
num_heads
=
lens
.
at
(
1
);
int
num_heads
=
lens
.
at
(
1
);
int
sequence_length
=
lens
.
at
(
2
);
int
sequence_length
=
lens
.
at
(
2
);
int
head_size
=
lens
.
back
();
int
head_size
=
lens
.
back
();
const
int
NH
=
num_heads
*
head_size
;
const
int
NH
=
num_heads
*
head_size
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
NHS
=
NH
*
sequence_length
;
//const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
//
const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
const
int
out_offset
=
n
*
head_size
+
s
*
NH
+
b
*
NHS
;
const
int
out_offset
=
n
*
head_size
+
s
*
NH
+
b
*
NHS
;
const
int
j
=
idx
.
back
();
const
int
j
=
idx
.
back
();
output
[
out_offset
+
j
]
=
input
[
i
];
output
[
out_offset
+
j
]
=
input
[
i
];
});
});
});
});
...
...
src/include/migraphx/op/transposeqkv.hpp
View file @
1974671d
...
@@ -33,7 +33,7 @@ struct transposeqkv
...
@@ -33,7 +33,7 @@ struct transposeqkv
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
// Input: BxSxKxNxH
// Input: BxSxKxNxH
// Output: KxBxNxSxH
// Output: KxBxNxSxH
// K is the number of identical matrices
// K is the number of identical matrices
...
@@ -53,12 +53,13 @@ struct transposeqkv
...
@@ -53,12 +53,13 @@ struct transposeqkv
const
int
num_heads
=
lens
[
3
];
const
int
num_heads
=
lens
[
3
];
const
int
sequence_length
=
lens
[
1
];
const
int
sequence_length
=
lens
[
1
];
const
int
batch_size
=
lens
[
0
];
const
int
batch_size
=
lens
[
0
];
const
int
H
=
lens
.
back
();
const
int
H
=
lens
.
back
();
const
int
NH
=
num_heads
*
H
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
out_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
+
m
*
NHS
*
batch_size
;
const
int
out_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
+
m
*
NHS
*
batch_size
;
output
[
out_offset
+
j
]
=
input
[
i
];
output
[
out_offset
+
j
]
=
input
[
i
];
});
});
...
...
src/onnx/parse_attention.cpp
View file @
1974671d
...
@@ -79,14 +79,11 @@ struct parse_attention : op_parser<parse_attention>
...
@@ -79,14 +79,11 @@ struct parse_attention : op_parser<parse_attention>
auto
ones
=
auto
ones
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
bias_type
,
ones_lens
},
ones_vec
});
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
bias_type
,
ones_lens
},
ones_vec
});
bias
=
info
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
n
,
1
}}}),
bias
);
bias
=
info
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
n
,
1
}}}),
bias
);
auto
gemm_1
=
info
.
add_instruction
(
auto
gemm_1
=
info
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
bias
,
ones
);
migraphx
::
make_op
(
"dot"
),
bias
,
ones
);
gemm_1
=
gemm_1
=
info
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
gemm_1
);
info
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
gemm_1
);
/// Use row-major => results(N, M) = 1 * input x weights + 1 x B
/// Use row-major => results(N, M) = 1 * input x weights + 1 x B
auto
input_sq
=
info
.
add_instruction
(
auto
input_sq
=
info
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
batch_size
*
sequence_length
,
hidden_size
}}}),
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
batch_size
*
sequence_length
,
hidden_size
}}}),
input
);
input
);
...
@@ -99,8 +96,7 @@ struct parse_attention : op_parser<parse_attention>
...
@@ -99,8 +96,7 @@ struct parse_attention : op_parser<parse_attention>
migraphx
::
make_op
(
"reshape"
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
batch_size
,
sequence_length
,
3
,
num_heads
,
head_size
}}}),
{{
"dims"
,
{
batch_size
,
sequence_length
,
3
,
num_heads
,
head_size
}}}),
add_gemms
);
add_gemms
);
auto
transqkv
=
info
.
add_instruction
(
auto
transqkv
=
info
.
add_instruction
(
migraphx
::
make_op
(
"transposeqkv"
),
add_gemms
);
migraphx
::
make_op
(
"transposeqkv"
),
add_gemms
);
// transqkv has shape 3xBxNxSxH
// transqkv has shape 3xBxNxSxH
// => Q, K, V: each has size BxNxSxH
// => Q, K, V: each has size BxNxSxH
...
@@ -155,7 +151,7 @@ struct parse_attention : op_parser<parse_attention>
...
@@ -155,7 +151,7 @@ struct parse_attention : op_parser<parse_attention>
// Inference mask is all 1s => masking can be skipped
// Inference mask is all 1s => masking can be skipped
auto
softmax
=
info
.
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
3
}}),
gemm3
);
auto
softmax
=
info
.
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
3
}}),
gemm3
);
// compute P*V
// compute P*V
auto
gemm4
=
info
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
softmax
,
v_t
);
auto
gemm4
=
info
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
softmax
,
v_t
);
// result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
// result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
...
...
src/targets/gpu/jit/bert_transpose.cpp
View file @
1974671d
...
@@ -40,7 +40,8 @@ struct transposectx_compiler : compiler<transposectx_compiler>
...
@@ -40,7 +40,8 @@ struct transposectx_compiler : compiler<transposectx_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
hip_compile_options
options
;
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
inputs
.
front
().
lens
().
back
());
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
inputs
.
front
().
lens
().
back
());
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"transposectx_kernel"
;
options
.
kernel_name
=
"transposectx_kernel"
;
...
@@ -78,7 +79,8 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
...
@@ -78,7 +79,8 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
hip_compile_options
options
;
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
inputs
.
front
().
lens
().
back
());
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
inputs
.
front
().
lens
().
back
());
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"transposeqkv_kernel"
;
options
.
kernel_name
=
"transposeqkv_kernel"
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
View file @
1974671d
...
@@ -11,12 +11,12 @@ __device__ void transposectx(const T& input_t, const U& output_t)
...
@@ -11,12 +11,12 @@ __device__ void transposectx(const T& input_t, const U& output_t)
{
{
// Input: BxNxSxH
// Input: BxNxSxH
// Output: BxSxNxH
// Output: BxSxNxH
auto
index
=
make_index
();
auto
index
=
make_index
();
auto
input_shape
=
input_t
.
get_shape
();
auto
input_shape
=
input_t
.
get_shape
();
auto
lens
=
input_shape
.
lens
;
auto
lens
=
input_shape
.
lens
;
const
int
num_heads
=
lens
[
1
];
const
int
num_heads
=
lens
[
1
];
const
int
sequence_length
=
lens
[
2
];
const
int
sequence_length
=
lens
[
2
];
int
head_size
=
lens
[
3
];
int
head_size
=
lens
[
3
];
auto
idx
=
input_shape
.
multi
(
index
.
global
);
auto
idx
=
input_shape
.
multi
(
index
.
global
);
...
@@ -24,11 +24,11 @@ __device__ void transposectx(const T& input_t, const U& output_t)
...
@@ -24,11 +24,11 @@ __device__ void transposectx(const T& input_t, const U& output_t)
int
s
=
idx
[
2
];
int
s
=
idx
[
2
];
int
b
=
idx
[
0
];
int
b
=
idx
[
0
];
const
int
NH
=
num_heads
*
head_size
;
const
int
NH
=
num_heads
*
head_size
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
out_offset
=
n
*
head_size
+
s
*
NH
+
b
*
NHS
;
const
int
out_offset
=
n
*
head_size
+
s
*
NH
+
b
*
NHS
;
if
(
index
.
local
<
1024
)
if
(
index
.
local
<
1024
)
output_t
[
out_offset
+
idx
[
3
]]
=
input_t
[
index
.
global
];
output_t
[
out_offset
+
idx
[
3
]]
=
input_t
[
index
.
global
];
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
View file @
1974671d
...
@@ -13,9 +13,9 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
...
@@ -13,9 +13,9 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
// Output: KxBxNxSxH
// Output: KxBxNxSxH
// K is the number of identical matrices
// K is the number of identical matrices
auto
index
=
make_index
();
auto
index
=
make_index
();
auto
input_shape
=
input_t
.
get_shape
();
auto
input_shape
=
input_t
.
get_shape
();
auto
lens
=
input_shape
.
lens
;
auto
lens
=
input_shape
.
lens
;
auto
idx
=
input_shape
.
multi
(
index
.
global
);
auto
idx
=
input_shape
.
multi
(
index
.
global
);
...
@@ -23,14 +23,14 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
...
@@ -23,14 +23,14 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
const
int
s
=
idx
[
1
];
const
int
s
=
idx
[
1
];
const
int
m
=
idx
[
2
];
const
int
m
=
idx
[
2
];
const
int
n
=
idx
[
3
];
const
int
n
=
idx
[
3
];
//const int j = idx[4];
//
const int j = idx[4];
const
int
num_heads
=
lens
[
3
];
const
int
num_heads
=
lens
[
3
];
const
int
sequence_length
=
lens
[
1
];
const
int
sequence_length
=
lens
[
1
];
const
int
batch_size
=
lens
[
0
];
const
int
batch_size
=
lens
[
0
];
const
int
H
=
lens
[
4
];
const
int
H
=
lens
[
4
];
const
int
NH
=
num_heads
*
H
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
out_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
+
m
*
NHS
*
batch_size
;
const
int
out_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
+
m
*
NHS
*
batch_size
;
...
...
test/ref_ops_test.cpp
View file @
1974671d
...
@@ -671,12 +671,12 @@ TEST_CASE(bert_transpose_ops_test)
...
@@ -671,12 +671,12 @@ TEST_CASE(bert_transpose_ops_test)
{
{
// transposeQKV
// transposeQKV
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
const
int
k
=
3
,
b
=
1
,
n
=
2
,
s
=
2
,
h
=
1
;
const
int
k
=
3
,
b
=
1
,
n
=
2
,
s
=
2
,
h
=
1
;
migraphx
::
shape
sh
{
migraphx
::
shape
::
float_type
,
{
b
,
s
,
k
,
n
,
h
}};
migraphx
::
shape
sh
{
migraphx
::
shape
::
float_type
,
{
b
,
s
,
k
,
n
,
h
}};
std
::
vector
<
float
>
data
(
b
*
s
*
k
*
n
*
h
);
std
::
vector
<
float
>
data
(
b
*
s
*
k
*
n
*
h
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"transposeqkv"
),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transposeqkv"
),
l1
);
p
.
compile
(
migraphx
::
ref
::
target
{});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
...
@@ -686,14 +686,15 @@ TEST_CASE(bert_transpose_ops_test)
...
@@ -686,14 +686,15 @@ TEST_CASE(bert_transpose_ops_test)
migraphx
::
program
p2
;
migraphx
::
program
p2
;
auto
*
mm2
=
p2
.
get_main_module
();
auto
*
mm2
=
p2
.
get_main_module
();
auto
l2
=
mm2
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
auto
l2
=
mm2
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
mm2
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
3
,
1
,
4
}}}),
l2
);
mm2
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
3
,
1
,
4
}}}),
l2
);
p2
.
compile
(
migraphx
::
ref
::
target
{});
p2
.
compile
(
migraphx
::
ref
::
target
{});
auto
result2
=
p2
.
eval
({}).
back
();
auto
result2
=
p2
.
eval
({}).
back
();
std
::
vector
<
float
>
result_vector2
(
k
*
b
*
n
*
s
*
h
);
std
::
vector
<
float
>
result_vector2
(
k
*
b
*
n
*
s
*
h
);
result2
.
visit
([
&
](
auto
output
)
{
result_vector2
.
assign
(
output
.
begin
(),
output
.
end
());
});
result2
.
visit
([
&
](
auto
output
)
{
result_vector2
.
assign
(
output
.
begin
(),
output
.
end
());
});
for
(
auto
&
i
:
result_vector2
)
for
(
auto
&
i
:
result_vector2
)
std
::
cout
<<
i
<<
", "
;
std
::
cout
<<
i
<<
", "
;
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
...
@@ -706,7 +707,7 @@ TEST_CASE(bert_transpose_ops_test)
...
@@ -706,7 +707,7 @@ TEST_CASE(bert_transpose_ops_test)
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
2
,
2
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
2
,
2
}};
std
::
vector
<
float
>
data
(
2
*
2
*
2
*
2
);
std
::
vector
<
float
>
data
(
2
*
2
*
2
*
2
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
s
,
data
});
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
s
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"transposectx"
),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transposectx"
),
l1
);
p
.
compile
(
migraphx
::
ref
::
target
{});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
...
@@ -719,18 +720,20 @@ TEST_CASE(bert_transpose_ops_test)
...
@@ -719,18 +720,20 @@ TEST_CASE(bert_transpose_ops_test)
{
{
// transposeCtx
// transposeCtx
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
const
int
b
=
2
,
n
=
2
,
s
=
3
,
h
=
4
;
const
int
b
=
2
,
n
=
2
,
s
=
3
,
h
=
4
;
migraphx
::
shape
sh
{
migraphx
::
shape
::
float_type
,
{
b
,
n
,
s
,
h
}};
migraphx
::
shape
sh
{
migraphx
::
shape
::
float_type
,
{
b
,
n
,
s
,
h
}};
std
::
vector
<
float
>
data
(
b
*
n
*
s
*
h
);
std
::
vector
<
float
>
data
(
b
*
n
*
s
*
h
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"transposectx"
),
l1
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transposectx"
),
l1
);
p
.
compile
(
migraphx
::
ref
::
target
{});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
result_vector
(
b
*
n
*
s
*
h
);
std
::
vector
<
float
>
result_vector
(
b
*
n
*
s
*
h
);
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
0
,
1
,
2
,
3
,
12
,
13
,
14
,
15
,
4
,
5
,
6
,
7
,
16
,
17
,
18
,
19
,
8
,
9
,
10
,
11
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
36
,
37
,
38
,
39
,
28
,
29
,
30
,
31
,
40
,
41
,
42
,
43
,
32
,
33
,
34
,
35
,
44
,
45
,
46
,
47
};
std
::
vector
<
float
>
gold
{
0
,
1
,
2
,
3
,
12
,
13
,
14
,
15
,
4
,
5
,
6
,
7
,
16
,
17
,
18
,
19
,
8
,
9
,
10
,
11
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
36
,
37
,
38
,
39
,
28
,
29
,
30
,
31
,
40
,
41
,
42
,
43
,
32
,
33
,
34
,
35
,
44
,
45
,
46
,
47
};
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment