Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
76f68df4
Commit
76f68df4
authored
Jul 19, 2018
by
wsttiger
Browse files
Merged from master
parents
dc0c4810
8ae3ffea
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
302 additions
and
51 deletions
+302
-51
src/targets/miopen/include/migraph/miopen/hip.hpp
src/targets/miopen/include/migraph/miopen/hip.hpp
+14
-0
src/targets/miopen/include/migraph/miopen/lowering.hpp
src/targets/miopen/include/migraph/miopen/lowering.hpp
+19
-0
src/targets/miopen/include/migraph/miopen/miopen.hpp
src/targets/miopen/include/migraph/miopen/miopen.hpp
+1
-1
src/targets/miopen/include/migraph/miopen/rocblas.hpp
src/targets/miopen/include/migraph/miopen/rocblas.hpp
+19
-0
src/targets/miopen/include/migraph/miopen/target.hpp
src/targets/miopen/include/migraph/miopen/target.hpp
+3
-3
src/targets/miopen/include/migraph/miopen/write_literals.hpp
src/targets/miopen/include/migraph/miopen/write_literals.hpp
+21
-0
src/targets/miopen/lowering.cpp
src/targets/miopen/lowering.cpp
+35
-41
src/targets/miopen/rocblas.cpp
src/targets/miopen/rocblas.cpp
+15
-0
src/targets/miopen/target.cpp
src/targets/miopen/target.cpp
+24
-0
src/targets/miopen/write_literals.cpp
src/targets/miopen/write_literals.cpp
+25
-0
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+78
-0
test/miopen/miopen.cpp
test/miopen/miopen.cpp
+19
-2
tools/include/context.hpp
tools/include/context.hpp
+6
-0
tools/include/operation.hpp
tools/include/operation.hpp
+8
-1
tools/te.py
tools/te.py
+15
-3
No files found.
src/targets/miopen/include/migraph/miopen/hip.hpp
View file @
76f68df4
...
...
@@ -26,6 +26,20 @@ struct hip_allocate
}
};
struct
hip_write
{
std
::
string
name
()
const
{
return
"hip::write"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
front
();
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
args
)
const
{
return
to_gpu
(
args
.
front
());
}
};
}
// namespace miopen
}
// namespace migraph
...
...
src/targets/miopen/include/migraph/miopen/lowering.hpp
0 → 100644
View file @
76f68df4
#ifndef MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP
#include <migraph/program.hpp>
namespace
migraph
{
namespace
miopen
{
struct
lowering
{
std
::
string
name
()
const
{
return
"miopen::lowering"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace miopen
}
// namespace migraph
#endif
src/targets/miopen/include/migraph/miopen/miopen.hpp
View file @
76f68df4
...
...
@@ -2,7 +2,7 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_MIOPEN_HPP
#include <migraph/manage_ptr.hpp>
#include <migraph/operators.hpp>
#include <miopen/miopen.h>
namespace
migraph
{
...
...
src/targets/miopen/include/migraph/miopen/rocblas.hpp
0 → 100644
View file @
76f68df4
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#include <migraph/manage_ptr.hpp>
#include <migraph/operators.hpp>
#include <rocblas.h>
namespace
migraph
{
namespace
miopen
{
using
rocblas_handle_ptr
=
MIGRAPH_MANAGE_PTR
(
rocblas_handle
,
rocblas_destroy_handle
);
rocblas_handle_ptr
create_rocblas_handle_ptr
();
}
// namespace miopen
}
// namespace migraph
#endif
src/targets/miopen/include/migraph/miopen/
miopen_
target.hpp
→
src/targets/miopen/include/migraph/miopen/target.hpp
View file @
76f68df4
...
...
@@ -6,11 +6,11 @@
namespace
migraph
{
namespace
miopen
{
struct
miopen_
target
struct
target
{
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
;
context
get_context
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraph
::
context
&
ctx
)
const
;
migraph
::
context
get_context
()
const
;
};
}
// namespace miopen
...
...
src/targets/miopen/include/migraph/miopen/write_literals.hpp
0 → 100644
View file @
76f68df4
#ifndef MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP
#include <migraph/program.hpp>
namespace
migraph
{
namespace
miopen
{
struct
write_literals
{
std
::
string
name
()
const
{
return
"miopen::write_literals"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace miopen
}
// namespace migraph
#endif
src/targets/miopen/
miopen_target
.cpp
→
src/targets/miopen/
lowering
.cpp
View file @
76f68df4
#include <migraph/miopen/miopen_target.hpp>
#include <rocblas.h>
#include <migraph/miopen/lowering.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
...
...
@@ -7,15 +8,13 @@
#include <migraph/miopen/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/miopen/kernels.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/miopen/rocblas.hpp>
#include <migraph/miopen/context.hpp>
namespace
migraph
{
namespace
miopen
{
struct
miopen_context
{
shared
<
miopen_handle
>
handle
;
};
struct
miopen_convolution
{
convolution
op
;
...
...
@@ -27,9 +26,8 @@ struct miopen_convolution
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
@@ -79,9 +77,8 @@ struct miopen_pooling
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
@@ -112,7 +109,7 @@ struct miopen_add
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
...
...
@@ -129,7 +126,6 @@ struct miopen_add
}
else
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
auto
a_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
1
].
get_shape
());
...
...
@@ -159,18 +155,31 @@ struct miopen_gemm
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
from_gpu
(
args
[
0
]),
from_gpu
(
args
[
1
]))(
[
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
dfor
(
input1
.
get_shape
().
lens
()[
0
],
input2
.
get_shape
().
lens
()[
1
],
input2
.
get_shape
().
lens
()[
0
])(
[
&
](
auto
i
,
auto
j
,
auto
k
)
{
output
(
i
,
j
)
+=
input1
(
i
,
k
)
*
input2
(
k
,
j
);
});
});
return
to_gpu
(
result
);
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
lens
()[
1
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
lens
()[
1
];
rocblas_int
m
=
output_shape
.
lens
()[
0
];
rocblas_int
n
=
output_shape
.
lens
()[
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_sgemm
(
ctx
.
rbhandle
.
get
(),
rocblas_operation_none
,
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha
,
args
[
1
].
implicit
(),
ldb
,
args
[
0
].
implicit
(),
lda
,
&
beta
,
args
[
2
].
implicit
(),
ldc
);
return
args
[
2
];
}
};
...
...
@@ -216,9 +225,8 @@ struct miopen_relu
return
inputs
.
at
(
1
);
}
argument
compute
(
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
@@ -241,7 +249,7 @@ struct miopen_apply
void
apply
()
{
prog
->
insert_instruction
(
prog
->
begin
(),
check_context
<
miopen_
context
>
{});
prog
->
insert_instruction
(
prog
->
begin
(),
check_context
<
context
>
{});
for
(
auto
it
=
prog
->
begin
();
it
!=
prog
->
end
();
it
++
)
{
if
(
it
->
op
.
name
()
==
"convolution"
)
...
...
@@ -354,21 +362,7 @@ struct miopen_apply
}
};
struct
miopen_pass
{
std
::
string
name
()
const
{
return
"miopen::pass"
;
}
void
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
();
}
};
std
::
vector
<
pass
>
miopen_target
::
get_passes
(
context
&
)
const
{
return
{
miopen_pass
{}};
}
std
::
string
miopen_target
::
name
()
const
{
return
"miopen"
;
}
context
miopen_target
::
get_context
()
const
{
return
miopen_context
{
share
(
make_obj
<
miopen_handle
>
(
&
miopenCreate
))};
}
void
lowering
::
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
();
}
}
// namespace miopen
...
...
src/targets/miopen/rocblas.cpp
0 → 100644
View file @
76f68df4
#include <migraph/miopen/rocblas.hpp>
namespace
migraph
{
namespace
miopen
{
rocblas_handle_ptr
create_rocblas_handle_ptr
()
{
rocblas_handle
handle
;
rocblas_create_handle
(
&
handle
);
return
rocblas_handle_ptr
{
handle
};
}
}
// namespace miopen
}
// namespace migraph
src/targets/miopen/target.cpp
0 → 100644
View file @
76f68df4
#include <migraph/miopen/target.hpp>
#include <migraph/miopen/lowering.hpp>
#include <migraph/miopen/write_literals.hpp>
#include <migraph/miopen/context.hpp>
namespace
migraph
{
namespace
miopen
{
std
::
vector
<
pass
>
target
::
get_passes
(
migraph
::
context
&
)
const
{
return
{
lowering
{},
write_literals
{}};
}
std
::
string
target
::
name
()
const
{
return
"miopen"
;
}
migraph
::
context
target
::
get_context
()
const
{
return
context
{
share
(
make_obj
<
miopen_handle
>
(
&
miopenCreate
)),
share
(
create_rocblas_handle_ptr
())};
}
}
// namespace miopen
}
// namespace migraph
src/targets/miopen/write_literals.cpp
0 → 100644
View file @
76f68df4
#include <migraph/miopen/write_literals.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/miopen/hip.hpp>
#include <migraph/instruction.hpp>
namespace
migraph
{
namespace
miopen
{
void
write_literals
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
op
.
name
()
==
"@literal"
)
{
literal
l
=
ins
->
lit
;
auto
pre
=
p
.
add_literal
(
l
);
p
.
replace_instruction
(
ins
,
hip_write
{},
pre
);
}
}
}
}
// namespace miopen
}
// namespace migraph
test/cpu_ops_test.cpp
View file @
76f68df4
...
...
@@ -6,6 +6,25 @@
#include "test.hpp"
#include "verify.hpp"
void
batch_norm_inference_test
()
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
4
}};
auto
x
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
1
,
2
,
3
,
4
}});
auto
gamma
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
1
}});
auto
beta
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
0
}});
auto
mean
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
0
}});
auto
variance
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
1
}});
p
.
add_instruction
(
migraph
::
batch_norm_inference
{},
x
,
mean
,
variance
,
gamma
,
beta
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
result_vector
(
4
);
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1
/
(
1
+
1.0e-6
),
2
/
(
1
+
1.0e-6
),
3
/
(
1
+
1.0e-6
),
4
/
(
1
+
1.0e-6
)};
EXPECT
(
test
::
verify_range
(
result_vector
,
gold
));
}
void
exp_test
()
{
migraph
::
program
p
;
...
...
@@ -252,6 +271,63 @@ void gemm_test()
}
}
void
maxpool_test
()
{
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
-
2.1314404
,
-
1.63041711
,
1.54562736
,
1.04625261
,
-
1.42931843
,
-
0.48703974
,
0.4065806
,
-
0.1524526
,
1.30775225
,
0.45538983
,
-
0.06631992
,
-
1.75332725
,
1.33493888
,
0.47327688
,
0.36873096
,
1.18358743
,
-
0.34640595
,
1.22098756
,
0.01946825
,
-
0.20238149
,
0.43348005
,
-
0.67991608
,
-
0.83041084
,
0.93537551
,
0.70241445
,
-
0.5654031
,
-
1.30899191
,
-
0.26735824
,
-
0.52444768
,
1.99097753
,
1.86504853
,
-
0.26506025
,
0.26236168
,
0.43763575
,
0.95300823
,
-
1.02733946
,
-
0.74655169
,
-
0.5374338
,
-
0.28901565
,
-
0.59789604
,
0.5310151
,
0.99125904
,
0.40609556
,
-
1.57175648
,
0.22031412
,
1.45862222
,
0.53217483
,
1.39087725
,
1.00170159
,
-
0.87175864
,
-
1.7204628
,
-
1.72008383
,
-
0.38656762
,
-
0.01443311
,
1.46645272
,
-
1.39995027
,
0.22505587
,
-
0.43461126
,
-
0.05511411
,
-
0.79950953
,
-
0.01439556
,
0.08795211
,
1.18943918
,
-
0.84079367
,
-
1.73383629
,
-
0.55662078
,
-
0.30626822
,
-
0.67339015
,
0.44179603
,
0.54316711
,
0.40899998
,
-
0.27831686
,
-
1.11900508
,
-
0.0881724
,
0.35483059
,
2.36277103
,
-
0.04765317
,
-
0.36865309
,
0.73814237
,
1.47151589
,
1.36546791
,
-
0.32649881
,
-
1.0517807
,
2.24768877
,
0.68883753
,
0.58646208
,
-
0.91017133
,
-
0.50462508
,
-
0.4013325
,
-
0.72348958
,
-
0.47368807
,
0.35285577
,
-
1.01817429
,
-
0.5152272
,
0.60321307
,
0.43521205
,
-
0.23733577
,
0.66427642
,
0.82949388
,
0.82443929
,
0.71550399
,
0.34561086
,
0.68570769
,
-
0.40718508
,
-
1.20350206
,
0.15793853
,
-
2.31013632
,
-
0.07934658
,
-
0.09348056
,
0.36576006
,
2.46601582
,
0.11090943
,
0.9144392
,
0.56759721
,
-
0.22112127
,
-
0.21955389
,
0.72474903
,
-
1.28448462
,
1.53285873
,
0.37437943
,
0.31409341
,
1.95433736
,
0.91620457
,
0.86205518
,
1.24365854
,
0.19248386
,
0.22526583
,
0.13462132
,
-
0.27561715
,
-
2.06446075
,
-
0.02306402
,
-
1.38278747
,
1.1411345
,
1.31293464
,
-
1.86041689
,
1.06763375
,
-
0.26541466
,
1.4545635
,
1.11430049
,
-
0.66491818
,
0.87101674
,
0.67768967
,
-
1.02062869
,
-
1.05031872
,
-
2.2764678
,
-
2.0200038
,
0.37592548
,
-
0.26701379
,
-
0.83388507
,
0.19403623
,
1.00968623
,
0.11020003
,
1.16736257
,
-
1.1160326
,
0.47346735
,
0.6126079
,
-
0.19135755
,
1.33624589
,
-
0.29802522
,
-
0.57873946
,
-
1.06555879
,
-
0.20686582
,
1.36892557
,
-
0.19937795
,
0.8649236
,
-
1.40126073
,
1.53441942
,
0.34682792
,
-
1.31724346
,
-
1.32898355
,
2.40126371
,
0.07845283
,
1.35732043
,
-
0.63678312
,
0.39429256
,
-
1.36487007
,
-
0.31026676
,
-
0.44981545
,
-
0.28994772
,
-
0.14657612
,
-
1.75206447
,
-
0.70612341
,
1.20071781
,
-
1.64647579
,
-
0.7133292
,
0.88494766
,
0.52119428
,
-
2.77387547
,
2.07681108
,
-
0.90133125
,
0.2847338
,
0.6174528
,
-
0.20616426
,
-
0.64263535
,
-
1.08496261
,
0.54275119
,
-
0.88503587
,
0.6629802
,
1.47319221
,
-
1.05829155
,
-
0.97027361
,
-
0.93187737
,
-
1.39954746
,
-
0.52359426
,
-
0.14743951
,
1.51522756
,
0.2078452
,
-
1.28156149
,
-
1.19363916
,
-
0.78680223
,
-
0.89094824
,
1.30212069
,
-
0.77974445
,
-
0.58411664
,
0.48764706
,
-
0.67132682
};
std
::
vector
<
float
>
c
=
{
1.33493888
,
1.54562736
,
1.22098756
,
1.33493888
,
1.18358743
,
1.99097753
,
1.00170159
,
1.45862222
,
1.39087725
,
1.46645272
,
1.18943918
,
-
0.01443311
,
1.47151589
,
2.36277103
,
2.24768877
,
0.68883753
,
0.82949388
,
0.71550399
,
1.95433736
,
2.46601582
,
1.53285873
,
1.95433736
,
1.06763375
,
1.4545635
,
1.33624589
,
1.16736257
,
0.6126079
,
1.36892557
,
2.40126371
,
1.53441942
,
0.52119428
,
2.07681108
,
0.88494766
,
1.51522756
,
0.54275119
,
0.6629802
};
migraph
::
shape
a_shape
{
migraph
::
shape
::
float_type
,
{
2
,
3
,
6
,
6
}};
auto
al
=
p
.
add_literal
(
migraph
::
literal
{
a_shape
,
a
});
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
3
,
2
}}},
al
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
cout
<<
result
.
get_shape
()
<<
std
::
endl
;
std
::
vector
<
float
>
results_vector
(
36
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
float
tol
=
1e-6
;
for
(
int
i
=
0
;
i
<
results_vector
.
size
();
i
++
)
{
// std::cout << results_vector[i] << " " << c[i] << std::endl;
EXPECT
(
std
::
abs
(
results_vector
[
i
]
-
c
[
i
])
<
tol
);
}
}
void
softmax_test
()
{
migraph
::
program
p
;
...
...
@@ -564,7 +640,9 @@ int main()
transpose_test
();
contiguous_test
();
softmax_test
();
// maxpool_test();
conv2d_test
();
conv2d_padding_test
();
conv2d_padding_stride_test
();
batch_norm_inference_test
();
}
test/miopen/miopen.cpp
View file @
76f68df4
...
...
@@ -3,7 +3,7 @@
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/miopen/
miopen_
target.hpp>
#include <migraph/miopen/target.hpp>
#include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/hip.hpp>
#include <migraph/manage_ptr.hpp>
...
...
@@ -27,7 +27,7 @@ migraph::argument run_gpu()
{
V
v
;
auto
p
=
v
.
create_program
();
p
.
compile
(
migraph
::
miopen
::
miopen_
target
{});
p
.
compile
(
migraph
::
miopen
::
target
{});
auto
m
=
v
.
create_params
();
for
(
auto
&&
e
:
m
)
...
...
@@ -49,6 +49,23 @@ void verify_program()
visit_all
(
cpu_arg
,
gpu_arg
)([](
auto
cpu
,
auto
gpu
)
{
EXPECT
(
test
::
verify_range
(
cpu
,
gpu
));
});
}
struct
test_literals
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
auto
input
=
p
.
add_literal
(
generate_literal
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}}));
auto
weights
=
p
.
add_literal
(
generate_literal
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}}));
auto
conv
=
p
.
add_instruction
(
migraph
::
convolution
{},
input
,
weights
);
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
conv
);
return
p
;
}
migraph
::
program
::
parameter_map
create_params
()
const
{
return
{};
}
};
struct
test_add
{
migraph
::
program
create_program
()
const
...
...
tools/include/context.hpp
View file @
76f68df4
#ifndef MIGRAPH_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace
migraph
{
<%
...
...
tools/include/operation.hpp
View file @
76f68df4
...
...
@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace
migraph
{
...
...
@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
<%
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
)
%>
...
...
tools/te.py
View file @
76f68df4
...
...
@@ -213,16 +213,21 @@ def internal_name(name):
else
:
return
name
def
generate_call
(
m
,
friend
):
def
generate_call
(
m
,
friend
,
indirect
):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
if
','
in
args
:
return
args
.
replace
(
','
,
op
)
else
:
return
string
.
Template
(
'${op}${arg
a
}'
).
substitute
(
op
=
op
,
args
=
args
)
return
string
.
Template
(
'${op}${arg
s
}'
).
substitute
(
op
=
op
,
args
=
args
)
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
if
indirect
:
if
m
[
'args'
]:
return
string
.
Template
(
'${default}(private_detail_te_value, ${args})'
).
substitute
(
m
)
else
:
return
string
.
Template
(
'${default}(private_detail_te_value)'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
...
...
@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
member_params
=
[]
skip
=
False
friend
=
False
indirect
=
False
if
'friend'
in
d
[
name
]:
friend
=
True
skip
=
True
if
'default'
in
d
[
name
]:
indirect
=
True
for
x
in
d
[
name
]:
t
=
d
[
name
][
x
]
if
x
==
'return'
:
...
...
@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
elif
x
==
'default'
:
member
[
'default'
]
=
t
elif
x
==
'using'
:
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
elif
x
.
startswith
(
'__'
)
and
x
.
endswith
(
'__'
):
continue
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
...
...
@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'call'
]
=
generate_call
(
member
,
friend
)
member
[
'call'
]
=
generate_call
(
member
,
friend
,
indirect
)
return
member
return
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment