gaoqiong/MIGraphX
Commit 5af79bd7, authored May 31, 2022 by turneram

Remove layernorm op

parent 89068ad1
Showing 13 changed files with 0 additions and 323 deletions:

src/CMakeLists.txt                                    +0  -1
src/include/migraphx/op/layernorm.hpp                 +0  -124
src/include/migraphx/operators.hpp                    +0  -1
src/onnx/parse_layernorm.cpp                          +0  -1
src/targets/gpu/CMakeLists.txt                        +0  -2
src/targets/gpu/include/migraphx/gpu/layernorm.hpp    +0  -40
src/targets/gpu/layernorm.cpp                         +0  -24
src/targets/gpu/lowering.cpp                          +0  -3
test/onnx/gen_onnx.py                                 +0  -16
test/onnx/layernorm_op_test.onnx                      +0  -24
test/onnx/verify_onnx.cpp                             +0  -25
test/ref_ops_test.cpp                                 +0  -44
test/verify/test_layernorm_op.cpp                     +0  -18
src/CMakeLists.txt

@@ -117,7 +117,6 @@ register_migraphx_ops(
     if_op
     im2col
     isnan
-    layernorm
     leaky_relu
     less
     load
src/include/migraphx/op/layernorm.hpp deleted 100644 → 0

#ifndef MIGRAPHX_GUARD_OPERATORS_LAYERNORMALIZATION_HPP
#define MIGRAPHX_GUARD_OPERATORS_LAYERNORMALIZATION_HPP

#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct layernorm
{
    float epsilon = 1e-3;
    int64_t axis  = -1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.epsilon, "epsilon"), f(self.axis, "axis"));
    }

    value attributes() const
    {
        value normalize;
        normalize["axis"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    std::string name() const { return "layernorm"; }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        if(inputs.size() == 2)
        {
            if(inputs.at(1).lens().front() != inputs.front().lens().at(axis))
                MIGRAPHX_THROW("LAYERNORM: weights have wrong shape");
        }
        if(inputs.size() == 3)
        {
            if(inputs.at(2).lens().front() != inputs.front().lens().at(axis))
                MIGRAPHX_THROW("LAYERNORM: bias has wrong shape");
        }
        return inputs.front();
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto x_lens     = args.front().get_shape().lens();
        auto norm_count = std::accumulate(x_lens.begin(),
                                          x_lens.begin() + axis,
                                          std::size_t{1},
                                          std::multiplies<std::size_t>());
        auto norm_size  = std::accumulate(x_lens.begin() + axis,
                                          x_lens.end(),
                                          std::size_t{1},
                                          std::multiplies<std::size_t>());
        if(args.size() == 3)
        {
            visit_all(result, args[0], args[1], args[2])(
                [&](auto output, auto data, auto weights, auto bias) {
                    par_for(norm_count, [&](auto idx) {
                        auto offset        = idx * norm_size;
                        double mean        = 0;
                        double mean_square = 0;
                        for(std::size_t i = 0; i < norm_size; ++i)
                        {
                            mean += data[offset + i];
                            mean_square += data[offset + i] * data[offset + i];
                        }
                        mean /= norm_size;
                        mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
                        for(std::size_t i = 0; i < norm_size; ++i)
                        {
                            if(args.size() == 3)
                                output[offset + i] =
                                    (data[offset + i] - mean) / mean_square * weights[i] + bias[i];
                            else
                                output[offset + i] =
                                    (data[offset + i] - mean) / mean_square * weights[i];
                        }
                    });
                });
        }
        else
        {
            visit_all(result, args[0])([&](auto output, auto data) {
                par_for(norm_count, [&](auto idx) {
                    auto offset        = idx * norm_size;
                    double mean        = 0;
                    double mean_square = 0;
                    for(std::size_t i = 0; i < norm_size; ++i)
                    {
                        mean += data[offset + i];
                        mean_square += data[offset + i] * data[offset + i];
                    }
                    mean /= norm_size;
                    mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
                    for(std::size_t i = 0; i < norm_size; ++i)
                    {
                        output[offset + i] = (data[offset + i] - mean) / mean_square;
                        // scale and bias handled by pointwise ops
                    }
                });
            });
        }
        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
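For context on what is being deleted: the reference compute above normalizes each of the norm_count leading slices to zero mean and unit variance over its norm_size trailing elements, computing the variance as E[x^2] - E[x]^2 and reusing the mean_square variable to hold the resulting standard deviation (epsilon is folded in before the square root). A minimal, self-contained sketch of the same arithmetic, with hypothetical names and a plain-array interface rather than the MIGraphX argument/visit_all machinery:

#include <cmath>
#include <cstddef>
#include <vector>

// layernorm_ref: hypothetical standalone version of op::layernorm::compute.
// Each contiguous run of norm_size elements is one normalization slice.
std::vector<double>
layernorm_ref(const std::vector<double>& x, std::size_t norm_size, double eps = 1e-3)
{
    std::vector<double> y(x.size());
    const std::size_t norm_count = x.size() / norm_size;
    for(std::size_t n = 0; n < norm_count; ++n)
    {
        const std::size_t offset = n * norm_size;
        double mean    = 0;
        double mean_sq = 0;
        for(std::size_t i = 0; i < norm_size; ++i)
        {
            mean += x[offset + i];
            mean_sq += x[offset + i] * x[offset + i];
        }
        mean /= norm_size;
        // Var(x) = E[x^2] - E[x]^2; eps is added before the square root.
        const double stddev = std::sqrt(mean_sq / norm_size - mean * mean + eps);
        for(std::size_t i = 0; i < norm_size; ++i)
            y[offset + i] = (x[offset + i] - mean) / stddev;
    }
    return y;
}

When scale and bias are given, they are applied per element of the slice afterwards, as in the weights[i] and bias[i] terms of the deleted compute.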
src/include/migraphx/operators.hpp

@@ -43,7 +43,6 @@
 #include <migraphx/op/if_op.hpp>
 #include <migraphx/op/im2col.hpp>
 #include <migraphx/op/isnan.hpp>
-#include <migraphx/op/layernorm.hpp>
 #include <migraphx/op/leaky_relu.hpp>
 #include <migraphx/op/less.hpp>
 #include <migraphx/op/load.hpp>
src/onnx/parse_layernorm.cpp

 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
-#include <migraphx/op/layernorm.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/instruction.hpp>
src/targets/gpu/CMakeLists.txt

@@ -148,7 +148,6 @@ add_library(migraphx_gpu
     int8_conv_pack.cpp
     int8_gemm_pack.cpp
     kernel.cpp
-    layernorm.cpp
     lowering.cpp
     logsoftmax.cpp
     loop.cpp
@@ -205,7 +204,6 @@ register_migraphx_gpu_ops(hip_
     floor
     gather
     greater
-    layernorm
     less
     log
     logsoftmax
src/targets/gpu/include/migraphx/gpu/layernorm.hpp deleted 100644 → 0

#ifndef MIGRAPHX_GUARD_RTGLIB_LAYERNORM_HPP
#define MIGRAPHX_GUARD_RTGLIB_LAYERNORM_HPP

#include <migraphx/op/layernorm.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/argument.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

struct hip_layernorm
{
    op::layernorm op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::layernorm"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    void finalize(context&, const shape&, const std::vector<shape>&);
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
src/targets/gpu/layernorm.cpp deleted 100644 → 0

#include <migraphx/gpu/layernorm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/layernorm.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_layernorm::compute_shape(std::vector<shape> inputs) const
{
    inputs.pop_back();
    return op.normalize_compute_shape(inputs);
}

argument
hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
    return args.back();
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
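A convention worth flagging in this deleted wrapper (my reading of the code, not spelled out in the commit): on the GPU target the final argument is a preallocated output buffer, which is why compute_shape pops the last input shape before validating, compute writes into args.back() and returns it, and output_alias in the header above reports shapes.size() - 1. A hypothetical, MIGraphX-free sketch of that calling pattern:

#include <vector>

// Hypothetical stand-in for a migraphx argument, to illustrate the convention only.
struct buffer { std::vector<float> data; };

// args = {input..., out}: the op fills the trailing buffer in place and
// returns it, so the result aliases the caller-allocated output.
buffer& run_with_output_buffer(std::vector<buffer>& args)
{
    buffer& out      = args.back();   // appended by the allocation pass
    const buffer& in = args.front();  // the layernorm input proper
    out.data = in.data;               // stand-in for device::layernorm(stream, out, in)
    return out;                       // same storage: output aliases args.back()
}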
src/targets/gpu/lowering.cpp

@@ -11,7 +11,6 @@
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/elu.hpp>
 #include <migraphx/op/if_op.hpp>
-#include <migraphx/op/layernorm.hpp>
 #include <migraphx/op/leaky_relu.hpp>
 #include <migraphx/op/lrn.hpp>
 #include <migraphx/op/pooling.hpp>
@@ -30,7 +29,6 @@
 #include <migraphx/gpu/gemm.hpp>
 #include <migraphx/gpu/greater.hpp>
 #include <migraphx/gpu/int8_conv_pack.hpp>
-#include <migraphx/gpu/layernorm.hpp>
 #include <migraphx/gpu/leaky_relu.hpp>
 #include <migraphx/gpu/less.hpp>
 #include <migraphx/gpu/logical_and.hpp>
@@ -141,7 +139,6 @@ struct miopen_apply
         add_generic_op("exp");
         add_generic_op("floor");
         add_generic_op("greater");
-        add_generic_op("layernorm");
         add_generic_op("less");
         add_generic_op("log");
         add_generic_op("logical_and");
test/onnx/gen_onnx.py

@@ -2644,22 +2644,6 @@ def layernorm_test():
                 bias_add], [x, scale, bias], [y], [pow_tensor, epsilon_tensor])
 
-@onnx_test
-def layernorm_op_test():
-    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2, 3])
-    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [3])
-    b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [3])
-    output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2, 3])
-
-    node = onnx.helper.make_node('LayerNormalization',
-                                 inputs=['x', 'w', 'b'],
-                                 outputs=["output"],
-                                 epsilon=1e-5)
-
-    return ([node], [x, w, b], [output])
-
-
 @onnx_test
 def leaky_relu_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
test/onnx/layernorm_op_test.onnx deleted 100644 → 0

(Binary ONNX protobuf for the deleted layernorm_op_test model: inputs x, w, b; output "output"; a single LayerNormalization node with an epsilon attribute; no newline at end of file. Raw bytes omitted.)
test/onnx/verify_onnx.cpp

@@ -472,31 +472,6 @@ TEST_CASE(instance_norm_3d_test)
     EXPECT(migraphx::verify_range(result_vector, gold));
 }
 
-TEST_CASE(layernorm_op_test)
-{
-    migraphx::program p = migraphx::parse_onnx("layernorm_op_test.onnx");
-    p.compile(migraphx::ref::target{});
-
-    migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
-    migraphx::shape swb{migraphx::shape::float_type, {3}};
-    std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
-    std::vector<float> w_vec{1.0, 1.0, 1.0};
-    std::vector<float> b_vec{0.0, 0.0, 0.0};
-
-    migraphx::parameter_map pp;
-    pp["x"] = migraphx::argument(sx, x_vec.data());
-    pp["w"] = migraphx::argument(swb, w_vec.data());
-    pp["b"] = migraphx::argument(swb, b_vec.data());
-
-    auto result = p.eval(pp).back();
-    std::vector<float> result_vector(6);
-    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
-
-    std::vector<float> gold{-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
-    EXPECT(migraphx::verify_range(result_vector, gold));
-}
-
 TEST_CASE(lessorequal_test)
 {
     migraphx::program p = migraphx::parse_onnx("lessorequal_test.onnx");
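As a sanity check on the gold values in this test and in the ref_ops test below (my own arithmetic, not part of the commit): each row of x is {1, 2, 3} or {4, 5, 6}, with mean 2 (respectively 5) and population variance 2/3, so with unit scale and zero bias the normalized values are 0 and ±1/sqrt(2/3) ≈ ±1.22474:

#include <cmath>
#include <cstdio>

int main()
{
    const double row[3] = {1, 2, 3}; // the second row {4, 5, 6} normalizes identically
    const double mean   = (row[0] + row[1] + row[2]) / 3.0; // 2
    double var = 0;
    for(double v : row)
        var += (v - mean) * (v - mean) / 3.0; // population variance = 2/3
    for(double v : row)
        std::printf("%.5f ", (v - mean) / std::sqrt(var + 1e-5));
    // prints: -1.22474 0.00000 1.22474, matching the gold vector
    return 0;
}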
test/ref_ops_test.cpp

@@ -2435,50 +2435,6 @@ TEST_CASE(imagescaler_test)
     EXPECT(migraphx::verify_range(results_vector, gold));
 }
 
-TEST_CASE(layernorm_test)
-{
-    {
-        // with scale and bias
-        migraphx::program p;
-        auto* mm = p.get_main_module();
-        migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
-        migraphx::shape swb{migraphx::shape::float_type, {3}};
-        std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
-        auto x = mm->add_literal(migraphx::literal{sx, x_vec});
-        auto w = mm->add_literal(migraphx::literal{swb, {1.0, 1.0, 1.0}});
-        auto b = mm->add_literal(migraphx::literal{swb, {0.0, 0.0, 0.0}});
-        mm->add_instruction(migraphx::make_op("layernorm", {{"epsilon", 1e-5}}), x, w, b);
-        p.compile(migraphx::ref::target{});
-        auto result = p.eval({}).back();
-        std::vector<float> results_vector(1 * 2 * 3);
-        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
-        std::vector<float> gold = {-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
-        for(auto&& i : results_vector)
-            std::cout << i << ", ";
-        std::cout << std::endl;
-        EXPECT(migraphx::verify_range(results_vector, gold));
-    }
-    {
-        // without scale and bias
-        migraphx::program p;
-        auto* mm = p.get_main_module();
-        migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
-        std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
-        auto x = mm->add_literal(migraphx::literal{sx, x_vec});
-        mm->add_instruction(migraphx::make_op("layernorm", {{"epsilon", 1e-5}}), x);
-        p.compile(migraphx::ref::target{});
-        auto result = p.eval({}).back();
-        std::vector<float> results_vector(1 * 2 * 3);
-        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
-        std::vector<float> gold = {-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
-        for(auto&& i : results_vector)
-            std::cout << i << ", ";
-        std::cout << std::endl;
-        EXPECT(migraphx::verify_range(results_vector, gold));
-    }
-}
-
 TEST_CASE(leaky_relu_test)
 {
     migraphx::program p;
test/verify/test_layernorm_op.cpp deleted 100644 → 0

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_layernorm_op : verify_program<test_layernorm_op>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto x =
            mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 384, 768}});
        mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}, {"epsilon", 1e-12}}), x);
        return p;
    }
};