Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d1481b13
"configs/datasets/vscode:/vscode.git/clone" did not exist on "e9cdb24dddde3419ae490a649cff9d79947ed45c"
Commit
d1481b13
authored
Aug 02, 2018
by
Paul
Browse files
Merge branch 'contigous-pass'
parents
b4d2a740
0df528ee
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
227 additions
and
86 deletions
+227
-86
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+22
-0
src/include/migraph/auto_contiguous.hpp
src/include/migraph/auto_contiguous.hpp
+19
-0
src/include/migraph/literal.hpp
src/include/migraph/literal.hpp
+24
-6
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+4
-4
src/include/migraph/program.hpp
src/include/migraph/program.hpp
+2
-0
src/include/migraph/shape.hpp
src/include/migraph/shape.hpp
+3
-1
src/include/migraph/tensor_view.hpp
src/include/migraph/tensor_view.hpp
+5
-5
src/program.cpp
src/program.cpp
+2
-0
src/shape.cpp
src/shape.cpp
+20
-15
src/targets/cpu/cpu_lowering.cpp
src/targets/cpu/cpu_lowering.cpp
+0
-26
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+0
-27
test/auto_contiguous_test.cpp
test/auto_contiguous_test.cpp
+58
-0
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+1
-1
test/include/basic_ops.hpp
test/include/basic_ops.hpp
+19
-0
test/shape_test.cpp
test/shape_test.cpp
+47
-1
No files found.
src/CMakeLists.txt
View file @
d1481b13
add_library
(
migraph
add_library
(
migraph
auto_contiguous.cpp
dead_code_elimination.cpp
dead_code_elimination.cpp
generate.cpp
generate.cpp
program.cpp
program.cpp
...
...
src/auto_contiguous.cpp
0 → 100644
View file @
d1481b13
#include <migraph/auto_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
namespace
migraph
{
void
auto_contiguous
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
shape
s
=
ins
->
result
;
if
(
not
s
.
standard
())
{
auto
prev
=
p
.
insert_instruction
(
ins
,
ins
->
op
,
ins
->
arguments
);
p
.
replace_instruction
(
ins
,
contiguous
{},
prev
);
}
}
}
}
// namespace migraph
src/include/migraph/auto_contiguous.hpp
0 → 100644
View file @
d1481b13
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace
migraph
{
struct
program
;
struct
auto_contiguous
{
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace migraph
#endif
src/include/migraph/literal.hpp
View file @
d1481b13
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp>
#include <migraph/raw_data.hpp>
...
@@ -26,24 +27,21 @@ struct literal : raw_data<literal>
...
@@ -26,24 +27,21 @@ struct literal : raw_data<literal>
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
{
{
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
}
);
fill
(
x
.
begin
(),
x
.
end
()
);
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
{
{
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
}
);
fill
(
x
.
begin
(),
x
.
end
()
);
}
}
template
<
class
Iterator
>
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
m_shape
(
s
)
{
{
assert
(
s
.
packed
());
fill
(
start
,
end
);
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
m_shape
(
s
)
{}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
m_shape
(
s
)
{}
...
@@ -66,6 +64,26 @@ struct literal : raw_data<literal>
...
@@ -66,6 +64,26 @@ struct literal : raw_data<literal>
private:
private:
std
::
vector
<
char
>
buffer
;
std
::
vector
<
char
>
buffer
;
shape
m_shape
;
shape
m_shape
;
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
{
if
(
m_shape
.
standard
())
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
else
{
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
data
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
it
++
;
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
});
});
}
}
};
};
}
// namespace migraph
}
// namespace migraph
...
...
src/include/migraph/operators.hpp
View file @
d1481b13
...
@@ -232,9 +232,9 @@ struct transpose
...
@@ -232,9 +232,9 @@ struct transpose
}
}
return
{
t
,
output_lens
,
output_strides
};
return
{
t
,
output_lens
,
output_strides
};
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
MIGRAPH_THROW
(
"not computable"
)
;
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)}
;
}
}
};
};
...
@@ -297,9 +297,9 @@ struct reshape
...
@@ -297,9 +297,9 @@ struct reshape
return
s
;
return
s
;
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
MIGRAPH_THROW
(
"not computable"
)
;
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)}
;
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
reshape
&
op
)
...
...
src/include/migraph/program.hpp
View file @
d1481b13
...
@@ -78,6 +78,8 @@ struct program
...
@@ -78,6 +78,8 @@ struct program
instruction_ref
begin
();
instruction_ref
begin
();
instruction_ref
end
();
instruction_ref
end
();
shape
get_shape
()
const
;
instruction_ref
validate
()
const
;
instruction_ref
validate
()
const
;
void
compile
(
const
target
&
t
);
void
compile
(
const
target
&
t
);
...
...
src/include/migraph/shape.hpp
View file @
d1481b13
...
@@ -76,7 +76,9 @@ struct shape
...
@@ -76,7 +76,9 @@ struct shape
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
size_t
index
(
std
::
size_t
i
)
const
;
bool
packed
()
const
;
bool
packed
()
const
;
bool
transposed
()
const
;
bool
broadcasted
()
const
;
bool
broadcasted
()
const
;
bool
standard
()
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
...
@@ -139,7 +141,7 @@ struct shape
...
@@ -139,7 +141,7 @@ struct shape
type_t
m_type
;
type_t
m_type
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_lens
;
std
::
vector
<
std
::
size_t
>
m_strides
;
std
::
vector
<
std
::
size_t
>
m_strides
;
bool
m_
packe
d
;
bool
m_
standar
d
;
void
calculate_strides
();
void
calculate_strides
();
std
::
size_t
element_space
()
const
;
std
::
size_t
element_space
()
const
;
...
...
src/include/migraph/tensor_view.hpp
View file @
d1481b13
...
@@ -88,16 +88,16 @@ struct tensor_view
...
@@ -88,16 +88,16 @@ struct tensor_view
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
return
m_data
[
m_shape
.
index
(
this
->
size
()
-
1
)];
}
}
// TODO: Add iterators so it can handle non
packe
d tensors
// TODO: Add iterators so it can handle non
standar
d tensors
T
*
begin
()
T
*
begin
()
{
{
assert
(
this
->
m_shape
.
packe
d
());
assert
(
this
->
m_shape
.
standar
d
());
return
m_data
;
return
m_data
;
}
}
T
*
end
()
T
*
end
()
{
{
assert
(
this
->
m_shape
.
packe
d
());
assert
(
this
->
m_shape
.
standar
d
());
if
(
this
->
empty
())
if
(
this
->
empty
())
return
m_data
;
return
m_data
;
else
else
...
@@ -106,13 +106,13 @@ struct tensor_view
...
@@ -106,13 +106,13 @@ struct tensor_view
const
T
*
begin
()
const
const
T
*
begin
()
const
{
{
assert
(
this
->
m_shape
.
packe
d
());
assert
(
this
->
m_shape
.
standar
d
());
return
m_data
;
return
m_data
;
}
}
const
T
*
end
()
const
const
T
*
end
()
const
{
{
assert
(
this
->
m_shape
.
packe
d
());
assert
(
this
->
m_shape
.
standar
d
());
if
(
this
->
empty
())
if
(
this
->
empty
())
return
m_data
;
return
m_data
;
else
else
...
...
src/program.cpp
View file @
d1481b13
...
@@ -126,6 +126,8 @@ bool program::has_instruction(instruction_ref ins) const
...
@@ -126,6 +126,8 @@ bool program::has_instruction(instruction_ref ins) const
instruction_ref
program
::
begin
()
{
return
impl
->
instructions
.
begin
();
}
instruction_ref
program
::
begin
()
{
return
impl
->
instructions
.
begin
();
}
instruction_ref
program
::
end
()
{
return
impl
->
instructions
.
end
();
}
instruction_ref
program
::
end
()
{
return
impl
->
instructions
.
end
();
}
shape
program
::
get_shape
()
const
{
return
impl
->
instructions
.
back
().
result
;
}
instruction_ref
program
::
validate
()
const
instruction_ref
program
::
validate
()
const
{
{
return
std
::
find_if
(
impl
->
instructions
.
begin
(),
return
std
::
find_if
(
impl
->
instructions
.
begin
(),
...
...
src/shape.cpp
View file @
d1481b13
...
@@ -8,10 +8,11 @@
...
@@ -8,10 +8,11 @@
namespace
migraph
{
namespace
migraph
{
shape
::
shape
()
:
m_type
(
float_type
),
m_
packe
d
(
false
)
{}
shape
::
shape
()
:
m_type
(
float_type
),
m_
standar
d
(
false
)
{}
shape
::
shape
(
type_t
t
)
:
m_type
(
t
),
m_lens
({
1
}),
m_strides
({
1
}),
m_packed
(
true
)
{}
shape
::
shape
(
type_t
t
)
:
m_type
(
t
),
m_lens
({
1
}),
m_strides
({
1
}),
m_standard
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_packed
(
true
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
m_type
(
t
),
m_lens
(
std
::
move
(
l
)),
m_standard
(
true
)
{
{
this
->
calculate_strides
();
this
->
calculate_strides
();
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
...
@@ -22,7 +23,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
...
@@ -22,7 +23,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
m_lens
.
size
()
==
m_strides
.
size
());
assert
(
std
::
any_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
>
0
;
})
and
assert
(
std
::
any_of
(
m_strides
.
begin
(),
m_strides
.
end
(),
[](
auto
x
)
{
return
x
>
0
;
})
and
"At least one stride must be non-zero"
);
"At least one stride must be non-zero"
);
m_
packe
d
=
this
->
elements
()
==
this
->
element_space
();
m_
standar
d
=
this
->
packed
()
and
not
this
->
transposed
();
}
}
void
shape
::
calculate_strides
()
void
shape
::
calculate_strides
()
...
@@ -66,7 +67,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
...
@@ -66,7 +67,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
if
(
this
->
packe
d
())
if
(
this
->
standar
d
())
return
i
;
return
i
;
else
else
return
std
::
inner_product
(
this
->
lens
().
begin
(),
return
std
::
inner_product
(
this
->
lens
().
begin
(),
...
@@ -79,7 +80,12 @@ std::size_t shape::index(std::size_t i) const
...
@@ -79,7 +80,12 @@ std::size_t shape::index(std::size_t i) const
return
((
i
/
stride
)
%
len
)
*
stride
;
return
((
i
/
stride
)
%
len
)
*
stride
;
});
});
}
}
bool
shape
::
packed
()
const
{
return
this
->
m_packed
;
}
bool
shape
::
packed
()
const
{
return
this
->
elements
()
==
this
->
element_space
();
}
bool
shape
::
transposed
()
const
{
return
not
std
::
is_sorted
(
this
->
strides
().
rbegin
(),
this
->
strides
().
rend
());
}
bool
shape
::
broadcasted
()
const
bool
shape
::
broadcasted
()
const
{
{
...
@@ -90,18 +96,17 @@ bool shape::broadcasted() const
...
@@ -90,18 +96,17 @@ bool shape::broadcasted() const
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
std
::
multiplies
<
std
::
size_t
>
())
==
0
;
}
}
bool
shape
::
standard
()
const
{
return
this
->
m_standard
;
}
std
::
size_t
shape
::
element_space
()
const
std
::
size_t
shape
::
element_space
()
const
{
{
// TODO: Get rid of intermediate vector
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
std
::
vector
<
std
::
size_t
>
max_indices
(
this
->
lens
().
size
());
return
std
::
inner_product
(
this
->
lens
().
begin
(),
std
::
transform
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
lens
().
end
(),
this
->
strides
().
begin
(),
std
::
vector
<
std
::
size_t
>
(
this
->
lens
().
size
(),
1
).
begin
(),
std
::
size_t
{
0
},
max_indices
.
begin
(),
std
::
plus
<
std
::
size_t
>
{},
std
::
minus
<
std
::
size_t
>
());
[](
std
::
size_t
l
,
std
::
size_t
s
)
{
return
(
l
-
1
)
*
s
;
})
+
return
std
::
inner_product
(
max_indices
.
begin
(),
max_indices
.
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
})
+
1
;
1
;
}
}
...
...
src/targets/cpu/cpu_lowering.cpp
View file @
d1481b13
...
@@ -203,18 +203,6 @@ struct cpu_pooling
...
@@ -203,18 +203,6 @@ struct cpu_pooling
}
}
};
};
struct
cpu_transpose
{
transpose
op
;
std
::
string
name
()
const
{
return
"cpu::transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
}
};
struct
cpu_contiguous
struct
cpu_contiguous
{
{
contiguous
op
;
contiguous
op
;
...
@@ -232,18 +220,6 @@ struct cpu_contiguous
...
@@ -232,18 +220,6 @@ struct cpu_contiguous
}
}
};
};
struct
cpu_reshape
{
reshape
op
;
std
::
string
name
()
const
{
return
"cpu::reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
}
};
struct
cpu_gemm
struct
cpu_gemm
{
{
gemm
op
;
gemm
op
;
...
@@ -545,9 +521,7 @@ struct cpu_apply
...
@@ -545,9 +521,7 @@ struct cpu_apply
apply_map
[
"gemm"
]
=
extend_op
<
cpu_gemm
,
gemm
>
();
apply_map
[
"gemm"
]
=
extend_op
<
cpu_gemm
,
gemm
>
();
apply_map
[
"batch_norm_inference"
]
=
apply_map
[
"batch_norm_inference"
]
=
extend_op
<
cpu_batch_norm_inference
,
batch_norm_inference
>
();
extend_op
<
cpu_batch_norm_inference
,
batch_norm_inference
>
();
apply_map
[
"reshape"
]
=
extend_op
<
cpu_reshape
,
reshape
>
();
apply_map
[
"contiguous"
]
=
extend_op
<
cpu_contiguous
,
contiguous
>
();
apply_map
[
"contiguous"
]
=
extend_op
<
cpu_contiguous
,
contiguous
>
();
apply_map
[
"transpose"
]
=
extend_op
<
cpu_transpose
,
transpose
>
();
apply_map
[
"identity"
]
=
simple_op
<
cpu_unary
<
identity_op
>>
();
apply_map
[
"identity"
]
=
simple_op
<
cpu_unary
<
identity_op
>>
();
apply_map
[
"tanh"
]
=
simple_op
<
cpu_unary
<
tanh_op
>>
();
apply_map
[
"tanh"
]
=
simple_op
<
cpu_unary
<
tanh_op
>>
();
...
...
src/targets/gpu/lowering.cpp
View file @
d1481b13
...
@@ -183,22 +183,6 @@ struct miopen_gemm
...
@@ -183,22 +183,6 @@ struct miopen_gemm
}
}
};
};
struct
miopen_transpose
{
transpose
op
;
std
::
string
name
()
const
{
return
"gpu::transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
}
};
struct
miopen_contiguous
struct
miopen_contiguous
{
{
contiguous
op
;
contiguous
op
;
...
@@ -271,10 +255,6 @@ struct miopen_apply
...
@@ -271,10 +255,6 @@ struct miopen_apply
{
{
apply_gemm
(
it
);
apply_gemm
(
it
);
}
}
else
if
(
it
->
op
.
name
()
==
"transpose"
)
{
apply_transpose
(
it
);
}
else
if
(
it
->
op
.
name
()
==
"contiguous"
)
else
if
(
it
->
op
.
name
()
==
"contiguous"
)
{
{
apply_contiguous
(
it
);
apply_contiguous
(
it
);
...
@@ -346,13 +326,6 @@ struct miopen_apply
...
@@ -346,13 +326,6 @@ struct miopen_apply
ins
,
miopen_gemm
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
miopen_gemm
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
}
}
void
apply_transpose
(
instruction_ref
ins
)
{
auto
&&
op
=
any_cast
<
transpose
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_transpose
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
}
void
apply_contiguous
(
instruction_ref
ins
)
void
apply_contiguous
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
op
);
...
...
test/auto_contiguous_test.cpp
0 → 100644
View file @
d1481b13
#include <migraph/auto_contiguous.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct
contiguous_target
{
std
::
string
name
()
const
{
return
"contiguous"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
return
{
migraph
::
auto_contiguous
{}};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
migraph
::
literal
get_2x2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
,
2
}},
{
1
,
2
,
3
,
4
}};
}
migraph
::
literal
get_2
()
{
return
migraph
::
literal
{{
migraph
::
shape
::
float_type
,
{
2
}},
{
1
,
2
}};
}
void
after_literal_transpose
()
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
auto
t
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
l
);
p
.
add_instruction
(
pass_op
{},
t
);
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
transposed
());
p
.
compile
(
contiguous_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
transposed
());
}
void
after_literal_broadcast
()
{
migraph
::
program
p
;
auto
l1
=
p
.
add_literal
(
get_2x2
());
auto
l2
=
p
.
add_literal
(
get_2
());
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
broadcasted
());
auto
b
=
p
.
add_instruction
(
migraph
::
broadcast
{},
l1
,
l2
);
p
.
add_instruction
(
pass_op
{},
b
);
EXPECT
(
not
p
.
get_shape
().
standard
());
EXPECT
(
p
.
get_shape
().
broadcasted
());
p
.
compile
(
contiguous_target
{});
EXPECT
(
p
.
get_shape
().
standard
());
EXPECT
(
not
p
.
get_shape
().
broadcasted
());
}
int
main
()
{
after_literal_transpose
();
after_literal_broadcast
();
}
test/cpu_ops_test.cpp
View file @
d1481b13
...
@@ -641,7 +641,7 @@ void contiguous_test()
...
@@ -641,7 +641,7 @@ void contiguous_test()
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
size_t
>
new_lens
=
{
1
,
3
,
2
,
2
};
std
::
vector
<
size_t
>
new_lens
=
{
1
,
3
,
2
,
2
};
std
::
vector
<
size_t
>
new_strides
=
{
12
,
1
,
6
,
3
};
std
::
vector
<
size_t
>
new_strides
=
{
12
,
1
,
6
,
3
};
std
::
vector
<
float
>
gold
=
{
0
,
3
,
6
,
9
,
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
};
std
::
vector
<
float
>
gold
=
{
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
,
3
,
6
,
9
,
0
};
EXPECT
(
test
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
test
::
verify_range
(
results_vector
,
gold
));
}
}
...
...
test/include/basic_ops.hpp
View file @
d1481b13
...
@@ -61,3 +61,22 @@ struct minus_op
...
@@ -61,3 +61,22 @@ struct minus_op
return
inputs
.
front
();
return
inputs
.
front
();
}
}
};
};
struct
pass_op
{
std
::
string
name
()
const
{
return
"pass"
;
}
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
if
(
args
.
empty
())
return
{};
return
args
.
front
();
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
inputs
)
const
{
if
(
inputs
.
empty
())
return
{};
return
inputs
.
front
();
}
};
test/shape_test.cpp
View file @
d1481b13
...
@@ -13,6 +13,42 @@ void test_shape_assign()
...
@@ -13,6 +13,42 @@ void test_shape_assign()
EXPECT
(
!
(
s1
!=
s2
));
EXPECT
(
!
(
s1
!=
s2
));
}
}
void
test_shape_packed_default
()
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
}
void
test_shape_packed
()
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
2
,
1
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
}
void
test_shape_transposed
()
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
EXPECT
(
not
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
}
void
test_shape_broadcasted
()
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
0
}};
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
s
.
broadcasted
());
}
void
test_shape_default
()
void
test_shape_default
()
{
{
migraph
::
shape
s1
{};
migraph
::
shape
s1
{};
...
@@ -24,7 +60,10 @@ void test_shape_default()
...
@@ -24,7 +60,10 @@ void test_shape_default()
void
test_shape4
()
void
test_shape4
()
{
{
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
}};
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
100
,
32
,
8
,
8
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
...
@@ -68,7 +107,10 @@ void test_shape4_nonpacked()
...
@@ -68,7 +107,10 @@ void test_shape4_nonpacked()
std
::
multiplies
<
std
::
size_t
>
());
std
::
multiplies
<
std
::
size_t
>
());
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
lens
,
strides
};
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
lens
,
strides
};
EXPECT
(
!
s
.
packed
());
EXPECT
(
not
s
.
standard
());
EXPECT
(
not
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
type
()
==
migraph
::
shape
::
float_type
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
0
]
==
100
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
EXPECT
(
s
.
lens
()[
1
]
==
32
);
...
@@ -95,6 +137,10 @@ void test_shape4_nonpacked()
...
@@ -95,6 +137,10 @@ void test_shape4_nonpacked()
int
main
()
int
main
()
{
{
test_shape_assign
();
test_shape_assign
();
test_shape_packed_default
();
test_shape_packed
();
test_shape_transposed
();
test_shape_broadcasted
();
test_shape_default
();
test_shape_default
();
test_shape4
();
test_shape4
();
test_shape4_nonpacked
();
test_shape4_nonpacked
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment