MIGraphX commit ba7a370a
Authored May 02, 2022 by turneram

Formatting

parent eea36256

Showing 7 changed files with 85 additions and 69 deletions (+85 −69)
src/include/migraphx/op/layernorm.hpp                      +1 −2
src/onnx/parse_attention.cpp                               +72 −58
src/onnx/parse_layernorm.cpp                               +5 −4
src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp  +2 −1
src/targets/gpu/layernorm.cpp                              +3 −2
src/targets/gpu/lowering.cpp                               +1 −1
test/onnx/gen_onnx.py                                      +1 −1
src/include/migraphx/op/layernorm.hpp

@@ -87,8 +87,7 @@ struct layernorm
        mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
        for(std::size_t i = 0; i < norm_size; ++i)
        {
            output[offset + i] = (data[offset + i] - mean) / mean_square;
            /* if(args.size() == 3)
                output[offset + i] =
                    (data[offset + i] - mean) / mean_square * weights[i] + bias[i];
...
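For reference, the expression in this hunk is the standard layer-normalization step: subtract the per-row mean and divide by sqrt(E[x^2] - mean^2 + epsilon). A minimal standalone sketch of that computation over a plain std::vector (illustrative only; layernorm_ref is not part of the MIGraphX API):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Normalize one row of norm_size elements, mirroring the hunk above:
    // mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon)
    std::vector<float> layernorm_ref(const std::vector<float>& data, float epsilon = 1e-3f)
    {
        const std::size_t norm_size = data.size();
        float mean        = 0.0f;
        float mean_square = 0.0f; // accumulates the sum of squares
        for(float x : data)
        {
            mean += x;
            mean_square += x * x;
        }
        mean /= norm_size;
        mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon);

        std::vector<float> output(norm_size);
        for(std::size_t i = 0; i < norm_size; ++i)
            output[i] = (data[i] - mean) / mean_square; // optionally * weights[i] + bias[i]
        return output;
    }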
src/onnx/parse_attention.cpp

@@ -15,25 +15,25 @@ struct parse_attention : op_parser<parse_attention>
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        auto input      = args[0];
        auto weights    = args[1];
        auto bias       = args[2];
        auto mask_index = args[3];
        instruction_ref past;
        instruction_ref extra_add_qk;
        bool is_past         = false;
        bool is_extra_add_qk = false;
        if(args.size() > 4)
        {
            past    = args[4];
            is_past = true;
        }
        if(args.size() == 6)
        {
            is_extra_add_qk = true;
            extra_add_qk    = args[5];
        }
        // ORT default is 12
...
@@ -42,112 +42,123 @@ struct parse_attention : op_parser<parse_attention>
        num_heads = info.attributes.at("num_heads").i();
        // input shape: (batch_size, sequence_length, input_hidden_size)
        auto input_lens        = input->get_shape().lens();
        auto batch_size        = input_lens.at(0);
        auto sequence_length   = input_lens.at(1);
        auto input_hidden_size = input_lens.at(2);
        // bias shape: (3 * hidden_size)
        auto bias_lens   = bias->get_shape().lens();
        auto hidden_size = bias_lens.at(0) / 3;
        auto head_size   = hidden_size / num_heads;
        int past_sequence_length = 0;
        // GetPresent
        // Input and output shapes:
        //   past    : (2, batch_size, num_heads, past_sequence_length, head_size)
        //   present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
        std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
        if(is_past)
        {
            auto past_lens       = past->get_shape().lens();
            past_sequence_length = past_lens.at(3);
            present_lens[3] += past_lens[3];
        }
        // Use GEMM for fully connection.
        auto m = batch_size * sequence_length;
        auto n = bias_lens.front();
        auto k = input_hidden_size;
        // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
        auto bias_type = bias->get_shape().type();
        std::vector<float> ones_vec(m, 1);
        std::vector<std::size_t> ones_lens{1, m};
        auto ones = info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
        bias      = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
        auto gemm_1 = info.add_instruction(
            migraphx::make_op("dot"),
            bias,
            ones /* info.make_contiguous(mb_bias), info.make_contiguous(ones) */);
        gemm_1 = info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
        /// ORT: Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x B.
        /// Assume row-major => results(N, M) = 1 * input x weights + 1 x B ?
        auto input_sq = info.add_instruction(
            migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
            input);
        auto gemm_2    = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
        auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
        // LaunchAttentionKernel:
        //   LaunchTransQkv
        //   input should be BxSx3xNxH => scratch3: 3xBxNxSxH
        add_gemms = info.add_instruction(
            migraphx::make_op("reshape",
                              {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
            add_gemms);
        std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
        auto transqkv =
            info.add_instruction(migraphx::make_op("transpose", {{"permutation", qkv_perm}}), add_gemms);
        // now scratch3 has Q, K, V: each has size BxNxSxH
        // => transqkv has shape 3xBxNxSxH
        auto batches        = batch_size * num_heads;
        auto size_per_batch = sequence_length * head_size;
        auto total_size     = batches * size_per_batch;
        auto q_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
        auto k_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
        auto v_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
        q_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), q_t);
        k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
        v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);
        if(is_past)
        {
            k_t = info.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), past, k_t);
            v_t = info.add_instruction(
                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
        }
        // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max
        // sequence length.
        auto mask_index_lens        = mask_index->get_shape().lens();
        bool use_raw_attention_mask = mask_index_lens.size() >= 2;
        // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
        // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
        const float rsqrt_head_size   = 1.f / sqrt(static_cast<float>(head_size));
        const int all_sequence_length = past_sequence_length + sequence_length;
        const int temp_matrix_size    = sequence_length * all_sequence_length;
        // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
        const float alpha = use_raw_attention_mask ? 1.0 : rsqrt_head_size;
        // K{B,N,S,H} -> K'{B,N,H,S}
        k_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
        auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
        if(is_extra_add_qk)
            gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
        auto alpha_lit = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
            info.add_literal(migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
        gemm3 = info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
        // apply softmax and store result P to scratch2: BxNxSxS*
        std::vector<float> mask(batch_size * num_heads * sequence_length * all_sequence_length, 0);
        if(false and mask_index_lens.size() >= 2)
        {
        }
        else if(false and mask_index_lens.size() == 1)
        {
        }
        // else => no mask
        auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
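One detail in the hunk above worth spelling out is the comment B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B: the length-N bias vector is broadcast across M columns by treating it as an (N, 1) column and multiplying it by a (1, M) row of ones, so the broadcast itself becomes a GEMM. A minimal loop-level sketch of what that product produces (illustrative only; broadcast_bias is not part of the MIGraphX API):

    #include <cstddef>
    #include <vector>

    // Expand a length-n bias into a row-major (n, m) matrix by computing
    // bias(n, 1) x ones(1, m): every column of the result is a copy of bias.
    std::vector<float> broadcast_bias(const std::vector<float>& bias, std::size_t m)
    {
        const std::size_t n = bias.size();
        std::vector<float> b(n * m, 0.0f);
        for(std::size_t i = 0; i < n; ++i)
            for(std::size_t j = 0; j < m; ++j)
                b[i * m + j] = bias[i] * 1.0f; // the "ones" factor
        return b;
    }

Adding this matrix to the (N, M) GEMM output adds the bias to every column, which is why the parser can fold the bias into a dot followed by an add.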
...
@@ -156,8 +167,11 @@ struct parse_attention : op_parser<parse_attention>
        auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
        // scratch3 is BxNxSxH, transpose to output BxSxNxH
        gemm4 = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
        gemm4 = info.add_instruction(
            make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
            info.make_contiguous(gemm4));
        return gemm4;
    }
};
...
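Taken together, the graph this parser builds is ordinary scaled dot-product attention: per head, scores = softmax(Q x K' / sqrt(H)) and output = scores x V, with the B, N, S, H axes handled by the reshape/transpose ops above. A minimal single-head, single-batch sketch of that math on row-major (S, H) matrices (illustrative only; not the MIGraphX API, and with no mask or past-state handling):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    using Matrix = std::vector<std::vector<float>>;

    // scores = softmax(Q * K^T / sqrt(H)); output = scores * V
    Matrix attention(const Matrix& q, const Matrix& k, const Matrix& v)
    {
        const std::size_t s = q.size();    // sequence length
        const std::size_t h = q[0].size(); // head size
        const float scale   = 1.0f / std::sqrt(static_cast<float>(h));

        // scores(S, S) = Q * K^T, scaled by 1/sqrt(H)
        Matrix scores(s, std::vector<float>(s, 0.0f));
        for(std::size_t i = 0; i < s; ++i)
            for(std::size_t j = 0; j < s; ++j)
            {
                for(std::size_t d = 0; d < h; ++d)
                    scores[i][j] += q[i][d] * k[j][d];
                scores[i][j] *= scale;
            }

        // row-wise softmax (no max-subtraction; fine for a small sketch)
        for(auto& row : scores)
        {
            float denom = 0.0f;
            for(float& x : row)
            {
                x = std::exp(x);
                denom += x;
            }
            for(float& x : row)
                x /= denom;
        }

        // output(S, H) = scores * V
        Matrix out(s, std::vector<float>(h, 0.0f));
        for(std::size_t i = 0; i < s; ++i)
            for(std::size_t j = 0; j < s; ++j)
                for(std::size_t d = 0; d < h; ++d)
                    out[i][d] += scores[i][j] * v[j][d];
        return out;
    }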
src/onnx/parse_layernorm.cpp

@@ -17,7 +17,7 @@ struct parse_layernorm : op_parser<parse_layernorm>
                          const std::vector<instruction_ref>& args) const
    {
        float epsilon = 1e-3f;
        int64_t axis  = -1;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
...
@@ -27,9 +27,10 @@ struct parse_layernorm : op_parser<parse_layernorm>
            epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
        }
        auto layernorm = info.add_instruction(
            make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
        if(args.size() == 3)
        {
            layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1));
            layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2));
...
src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp

@@ -12,7 +12,8 @@ namespace device {
void layernorm(hipStream_t stream, const argument& result, const argument& arg1);
// void layernorm(hipStream_t stream, const argument& result, const argument& arg1, const argument&
// arg2, const argument& arg3, const int64_t axis);
void triadd_layernorm(hipStream_t stream,
                      const argument& result,
...
src/targets/gpu/layernorm.cpp

@@ -19,12 +19,13 @@ argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<ar
    {
        auto n_dim      = args.front().get_shape().lens().size();
        auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
        device::layernorm(ctx.get_stream().get(), args.back(), args[0], args[1], args[2],
                          tuned_axis);
    }
    else */
    std::cout << "calling device::ln" << std::endl;
    {
        device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
        std::cout << "called device::ln" << std::endl;
    }
...
src/targets/gpu/lowering.cpp

@@ -389,7 +389,7 @@ struct miopen_apply
        apply_map.emplace(op_name, [=](instruction_ref ins) {
            auto output = insert_allocation(ins, ins->get_shape());
            std::vector<instruction_ref> refs = ins->inputs();
            if(op_name == "layernorm")
            {
                std::cout << "layernorm op" << std::endl;
            }
...
test/onnx/gen_onnx.py

@@ -2608,7 +2608,7 @@ def layernorm_op_test():
    return ([node], [x, w, b], [output])


@onnx_test
def leaky_relu_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...