Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
375c7b8d
"vscode:/vscode.git/clone" did not exist on "0c3827b3a57adf3eca759bc523087b2478300aab"
Commit
375c7b8d
authored
May 27, 2022
by
turneram
Browse files
Remove layernorm operator
parent
276dda76
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
0 additions
and
220 deletions
+0
-220
src/CMakeLists.txt
src/CMakeLists.txt
+0
-1
src/include/migraphx/op/layernorm.hpp
src/include/migraphx/op/layernorm.hpp
+0
-124
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+0
-1
src/onnx/parse_layernorm.cpp
src/onnx/parse_layernorm.cpp
+0
-1
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+0
-2
src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp
src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp
+0
-26
src/targets/gpu/include/migraphx/gpu/layernorm.hpp
src/targets/gpu/include/migraphx/gpu/layernorm.hpp
+0
-40
src/targets/gpu/layernorm.cpp
src/targets/gpu/layernorm.cpp
+0
-24
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+0
-1
No files found.
src/CMakeLists.txt
View file @
375c7b8d
...
@@ -117,7 +117,6 @@ register_migraphx_ops(
...
@@ -117,7 +117,6 @@ register_migraphx_ops(
if_op
if_op
im2col
im2col
isnan
isnan
layernorm
leaky_relu
leaky_relu
less
less
load
load
...
...
src/include/migraphx/op/layernorm.hpp
deleted
100644 → 0
View file @
276dda76
#ifndef MIGRAPHX_GUARD_OPERATORS_LAYERNORMALIZATION_HPP
#define MIGRAPHX_GUARD_OPERATORS_LAYERNORMALIZATION_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
layernorm
{
float
epsilon
=
1e-3
;
int64_t
axis
=
-
1
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
epsilon
,
"epsilon"
),
f
(
self
.
axis
,
"axis"
));
}
value
attributes
()
const
{
value
normalize
;
normalize
[
"axis"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
return
{{
"normalize_axes"
,
normalize
}};
}
std
::
string
name
()
const
{
return
"layernorm"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
size
()
==
2
)
{
if
(
inputs
.
at
(
1
).
lens
().
front
()
!=
inputs
.
front
().
lens
().
at
(
axis
))
MIGRAPHX_THROW
(
"LAYERNORM: weights have wrong shape"
);
}
if
(
inputs
.
size
()
==
3
)
{
if
(
inputs
.
at
(
2
).
lens
().
front
()
!=
inputs
.
front
().
lens
().
at
(
axis
))
MIGRAPHX_THROW
(
"LAYERNORM: bias has wrong shape"
);
}
return
inputs
.
front
();
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
x_lens
=
args
.
front
().
get_shape
().
lens
();
auto
norm_count
=
std
::
accumulate
(
x_lens
.
begin
(),
x_lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
auto
norm_size
=
std
::
accumulate
(
x_lens
.
begin
()
+
axis
,
x_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
args
.
size
()
==
3
)
{
visit_all
(
result
,
args
[
0
],
args
[
1
],
args
[
2
])(
[
&
](
auto
output
,
auto
data
,
auto
weights
,
auto
bias
)
{
par_for
(
norm_count
,
[
&
](
auto
idx
)
{
auto
offset
=
idx
*
norm_size
;
double
mean
=
0
;
double
mean_square
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
norm_size
;
++
i
)
{
mean
+=
data
[
offset
+
i
];
mean_square
+=
data
[
offset
+
i
]
*
data
[
offset
+
i
];
}
mean
/=
norm_size
;
mean_square
=
sqrt
(
mean_square
/
norm_size
-
mean
*
mean
+
epsilon
);
for
(
std
::
size_t
i
=
0
;
i
<
norm_size
;
++
i
)
{
if
(
args
.
size
()
==
3
)
output
[
offset
+
i
]
=
(
data
[
offset
+
i
]
-
mean
)
/
mean_square
*
weights
[
i
]
+
bias
[
i
];
else
output
[
offset
+
i
]
=
(
data
[
offset
+
i
]
-
mean
)
/
mean_square
*
weights
[
i
];
}
});
});
}
else
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
data
)
{
par_for
(
norm_count
,
[
&
](
auto
idx
)
{
auto
offset
=
idx
*
norm_size
;
double
mean
=
0
;
double
mean_square
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
norm_size
;
++
i
)
{
mean
+=
data
[
offset
+
i
];
mean_square
+=
data
[
offset
+
i
]
*
data
[
offset
+
i
];
}
mean
/=
norm_size
;
mean_square
=
sqrt
(
mean_square
/
norm_size
-
mean
*
mean
+
epsilon
);
for
(
std
::
size_t
i
=
0
;
i
<
norm_size
;
++
i
)
{
output
[
offset
+
i
]
=
(
data
[
offset
+
i
]
-
mean
)
/
mean_square
;
// scale and bias handled by pointwise ops
}
});
});
}
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
375c7b8d
...
@@ -43,7 +43,6 @@
...
@@ -43,7 +43,6 @@
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/isnan.hpp>
#include <migraphx/op/isnan.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp>
#include <migraphx/op/less.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/load.hpp>
...
...
src/onnx/parse_layernorm.cpp
View file @
375c7b8d
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
...
...
src/targets/gpu/CMakeLists.txt
View file @
375c7b8d
...
@@ -148,7 +148,6 @@ add_library(migraphx_gpu
...
@@ -148,7 +148,6 @@ add_library(migraphx_gpu
int8_conv_pack.cpp
int8_conv_pack.cpp
int8_gemm_pack.cpp
int8_gemm_pack.cpp
kernel.cpp
kernel.cpp
layernorm.cpp
lowering.cpp
lowering.cpp
logsoftmax.cpp
logsoftmax.cpp
loop.cpp
loop.cpp
...
@@ -205,7 +204,6 @@ register_migraphx_gpu_ops(hip_
...
@@ -205,7 +204,6 @@ register_migraphx_gpu_ops(hip_
floor
floor
gather
gather
greater
greater
layernorm
less
less
log
log
logsoftmax
logsoftmax
...
...
src/targets/gpu/include/migraphx/gpu/device/layernorm.hpp
deleted
100644 → 0
View file @
276dda76
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_LAYERNORM_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_LAYERNORM_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
layernorm
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
);
void
triadd_layernorm
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
);
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/include/migraphx/gpu/layernorm.hpp
deleted
100644 → 0
View file @
276dda76
#ifndef MIGRAPHX_GUARD_RTGLIB_LAYERNORM_HPP
#define MIGRAPHX_GUARD_RTGLIB_LAYERNORM_HPP
#include <migraphx/op/layernorm.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/argument.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
context
;
struct
hip_layernorm
{
op
::
layernorm
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
}
std
::
string
name
()
const
{
return
"gpu::layernorm"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
void
finalize
(
context
&
,
const
shape
&
,
const
std
::
vector
<
shape
>&
);
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/layernorm.cpp
deleted
100644 → 0
View file @
276dda76
#include <migraphx/gpu/layernorm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/layernorm.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
shape
hip_layernorm
::
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
inputs
.
pop_back
();
return
op
.
normalize_compute_shape
(
inputs
);
}
argument
hip_layernorm
::
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
layernorm
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
[
0
]);
return
args
.
back
();
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/lowering.cpp
View file @
375c7b8d
...
@@ -11,7 +11,6 @@
...
@@ -11,7 +11,6 @@
#include <migraphx/op/dot.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pooling.hpp>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment