Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
76f68df4
"docs/vscode:/vscode.git/clone" did not exist on "bd7cfbd2f852c1a55b83c95163526e04971ebab9"
Commit
76f68df4
authored
Jul 19, 2018
by
wsttiger
Browse files
Merged from master
parents
dc0c4810
8ae3ffea
Changes
35
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
302 additions
and
51 deletions
+302
-51
src/targets/miopen/include/migraph/miopen/hip.hpp
src/targets/miopen/include/migraph/miopen/hip.hpp
+14
-0
src/targets/miopen/include/migraph/miopen/lowering.hpp
src/targets/miopen/include/migraph/miopen/lowering.hpp
+19
-0
src/targets/miopen/include/migraph/miopen/miopen.hpp
src/targets/miopen/include/migraph/miopen/miopen.hpp
+1
-1
src/targets/miopen/include/migraph/miopen/rocblas.hpp
src/targets/miopen/include/migraph/miopen/rocblas.hpp
+19
-0
src/targets/miopen/include/migraph/miopen/target.hpp
src/targets/miopen/include/migraph/miopen/target.hpp
+3
-3
src/targets/miopen/include/migraph/miopen/write_literals.hpp
src/targets/miopen/include/migraph/miopen/write_literals.hpp
+21
-0
src/targets/miopen/lowering.cpp
src/targets/miopen/lowering.cpp
+35
-41
src/targets/miopen/rocblas.cpp
src/targets/miopen/rocblas.cpp
+15
-0
src/targets/miopen/target.cpp
src/targets/miopen/target.cpp
+24
-0
src/targets/miopen/write_literals.cpp
src/targets/miopen/write_literals.cpp
+25
-0
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+78
-0
test/miopen/miopen.cpp
test/miopen/miopen.cpp
+19
-2
tools/include/context.hpp
tools/include/context.hpp
+6
-0
tools/include/operation.hpp
tools/include/operation.hpp
+8
-1
tools/te.py
tools/te.py
+15
-3
No files found.
src/targets/miopen/include/migraph/miopen/hip.hpp
View file @
76f68df4
...
...
@@ -26,6 +26,20 @@ struct hip_allocate
}
};
struct
hip_write
{
std
::
string
name
()
const
{
return
"hip::write"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
front
();
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
args
)
const
{
return
to_gpu
(
args
.
front
());
}
};
}
// namespace miopen
}
// namespace migraph
...
...
src/targets/miopen/include/migraph/miopen/lowering.hpp
0 → 100644
View file @
76f68df4
#ifndef MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP
#include <migraph/program.hpp>
namespace
migraph
{
namespace
miopen
{
struct
lowering
{
std
::
string
name
()
const
{
return
"miopen::lowering"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace miopen
}
// namespace migraph
#endif
src/targets/miopen/include/migraph/miopen/miopen.hpp
View file @
76f68df4
...
...
@@ -2,7 +2,7 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_MIOPEN_HPP
#include <migraph/manage_ptr.hpp>
#include <migraph/operators.hpp>
#include <miopen/miopen.h>
namespace
migraph
{
...
...
src/targets/miopen/include/migraph/miopen/rocblas.hpp
0 → 100644
View file @
76f68df4
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#include <migraph/manage_ptr.hpp>
#include <migraph/operators.hpp>
#include <rocblas.h>
namespace
migraph
{
namespace
miopen
{
using
rocblas_handle_ptr
=
MIGRAPH_MANAGE_PTR
(
rocblas_handle
,
rocblas_destroy_handle
);
rocblas_handle_ptr
create_rocblas_handle_ptr
();
}
// namespace miopen
}
// namespace migraph
#endif
src/targets/miopen/include/migraph/miopen/
miopen_
target.hpp
→
src/targets/miopen/include/migraph/miopen/target.hpp
View file @
76f68df4
...
...
@@ -6,11 +6,11 @@
namespace
migraph
{
namespace
miopen
{
struct
miopen_
target
struct
target
{
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
context
&
ctx
)
const
;
context
get_context
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraph
::
context
&
ctx
)
const
;
migraph
::
context
get_context
()
const
;
};
}
// namespace miopen
...
...
src/targets/miopen/include/migraph/miopen/write_literals.hpp
0 → 100644
View file @
76f68df4
#ifndef MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP
#include <migraph/program.hpp>
namespace
migraph
{
namespace
miopen
{
struct
write_literals
{
std
::
string
name
()
const
{
return
"miopen::write_literals"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace miopen
}
// namespace migraph
#endif
src/targets/miopen/
miopen_target
.cpp
→
src/targets/miopen/
lowering
.cpp
View file @
76f68df4
#include <migraph/miopen/miopen_target.hpp>
#include <rocblas.h>
#include <migraph/miopen/lowering.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
...
...
@@ -7,15 +8,13 @@
#include <migraph/miopen/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/miopen/kernels.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/miopen/rocblas.hpp>
#include <migraph/miopen/context.hpp>
namespace
migraph
{
namespace
miopen
{
struct
miopen_context
{
shared
<
miopen_handle
>
handle
;
};
struct
miopen_convolution
{
convolution
op
;
...
...
@@ -27,9 +26,8 @@ struct miopen_convolution
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
@@ -79,9 +77,8 @@ struct miopen_pooling
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
@@ -112,7 +109,7 @@ struct miopen_add
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
...
...
@@ -129,7 +126,6 @@ struct miopen_add
}
else
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
auto
a_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
1
].
get_shape
());
...
...
@@ -159,18 +155,31 @@ struct miopen_gemm
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
from_gpu
(
args
[
0
]),
from_gpu
(
args
[
1
]))(
[
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
dfor
(
input1
.
get_shape
().
lens
()[
0
],
input2
.
get_shape
().
lens
()[
1
],
input2
.
get_shape
().
lens
()[
0
])(
[
&
](
auto
i
,
auto
j
,
auto
k
)
{
output
(
i
,
j
)
+=
input1
(
i
,
k
)
*
input2
(
k
,
j
);
});
});
return
to_gpu
(
result
);
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
lens
()[
1
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
lens
()[
1
];
rocblas_int
m
=
output_shape
.
lens
()[
0
];
rocblas_int
n
=
output_shape
.
lens
()[
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_sgemm
(
ctx
.
rbhandle
.
get
(),
rocblas_operation_none
,
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha
,
args
[
1
].
implicit
(),
ldb
,
args
[
0
].
implicit
(),
lda
,
&
beta
,
args
[
2
].
implicit
(),
ldc
);
return
args
[
2
];
}
};
...
...
@@ -216,9 +225,8 @@ struct miopen_relu
return
inputs
.
at
(
1
);
}
argument
compute
(
context
&
g
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
...
...
@@ -241,7 +249,7 @@ struct miopen_apply
void
apply
()
{
prog
->
insert_instruction
(
prog
->
begin
(),
check_context
<
miopen_
context
>
{});
prog
->
insert_instruction
(
prog
->
begin
(),
check_context
<
context
>
{});
for
(
auto
it
=
prog
->
begin
();
it
!=
prog
->
end
();
it
++
)
{
if
(
it
->
op
.
name
()
==
"convolution"
)
...
...
@@ -354,21 +362,7 @@ struct miopen_apply
}
};
struct
miopen_pass
{
std
::
string
name
()
const
{
return
"miopen::pass"
;
}
void
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
();
}
};
std
::
vector
<
pass
>
miopen_target
::
get_passes
(
context
&
)
const
{
return
{
miopen_pass
{}};
}
std
::
string
miopen_target
::
name
()
const
{
return
"miopen"
;
}
context
miopen_target
::
get_context
()
const
{
return
miopen_context
{
share
(
make_obj
<
miopen_handle
>
(
&
miopenCreate
))};
}
void
lowering
::
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
();
}
}
// namespace miopen
...
...
src/targets/miopen/rocblas.cpp
0 → 100644
View file @
76f68df4
#include <migraph/miopen/rocblas.hpp>
namespace
migraph
{
namespace
miopen
{
rocblas_handle_ptr
create_rocblas_handle_ptr
()
{
rocblas_handle
handle
;
rocblas_create_handle
(
&
handle
);
return
rocblas_handle_ptr
{
handle
};
}
}
// namespace miopen
}
// namespace migraph
src/targets/miopen/target.cpp
0 → 100644
View file @
76f68df4
#include <migraph/miopen/target.hpp>
#include <migraph/miopen/lowering.hpp>
#include <migraph/miopen/write_literals.hpp>
#include <migraph/miopen/context.hpp>
namespace
migraph
{
namespace
miopen
{
std
::
vector
<
pass
>
target
::
get_passes
(
migraph
::
context
&
)
const
{
return
{
lowering
{},
write_literals
{}};
}
std
::
string
target
::
name
()
const
{
return
"miopen"
;
}
migraph
::
context
target
::
get_context
()
const
{
return
context
{
share
(
make_obj
<
miopen_handle
>
(
&
miopenCreate
)),
share
(
create_rocblas_handle_ptr
())};
}
}
// namespace miopen
}
// namespace migraph
src/targets/miopen/write_literals.cpp
0 → 100644
View file @
76f68df4
#include <migraph/miopen/write_literals.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/miopen/hip.hpp>
#include <migraph/instruction.hpp>
namespace
migraph
{
namespace
miopen
{
void
write_literals
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
op
.
name
()
==
"@literal"
)
{
literal
l
=
ins
->
lit
;
auto
pre
=
p
.
add_literal
(
l
);
p
.
replace_instruction
(
ins
,
hip_write
{},
pre
);
}
}
}
}
// namespace miopen
}
// namespace migraph
test/cpu_ops_test.cpp
View file @
76f68df4
...
...
@@ -6,6 +6,25 @@
#include "test.hpp"
#include "verify.hpp"
void
batch_norm_inference_test
()
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
4
}};
auto
x
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
1
,
2
,
3
,
4
}});
auto
gamma
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
1
}});
auto
beta
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
0
}});
auto
mean
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
0
}});
auto
variance
=
p
.
add_literal
(
migraph
::
literal
{
s
,
{
1
}});
p
.
add_instruction
(
migraph
::
batch_norm_inference
{},
x
,
mean
,
variance
,
gamma
,
beta
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
result_vector
(
4
);
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1
/
(
1
+
1.0e-6
),
2
/
(
1
+
1.0e-6
),
3
/
(
1
+
1.0e-6
),
4
/
(
1
+
1.0e-6
)};
EXPECT
(
test
::
verify_range
(
result_vector
,
gold
));
}
void
exp_test
()
{
migraph
::
program
p
;
...
...
@@ -252,6 +271,63 @@ void gemm_test()
}
}
void
maxpool_test
()
{
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
-
2.1314404
,
-
1.63041711
,
1.54562736
,
1.04625261
,
-
1.42931843
,
-
0.48703974
,
0.4065806
,
-
0.1524526
,
1.30775225
,
0.45538983
,
-
0.06631992
,
-
1.75332725
,
1.33493888
,
0.47327688
,
0.36873096
,
1.18358743
,
-
0.34640595
,
1.22098756
,
0.01946825
,
-
0.20238149
,
0.43348005
,
-
0.67991608
,
-
0.83041084
,
0.93537551
,
0.70241445
,
-
0.5654031
,
-
1.30899191
,
-
0.26735824
,
-
0.52444768
,
1.99097753
,
1.86504853
,
-
0.26506025
,
0.26236168
,
0.43763575
,
0.95300823
,
-
1.02733946
,
-
0.74655169
,
-
0.5374338
,
-
0.28901565
,
-
0.59789604
,
0.5310151
,
0.99125904
,
0.40609556
,
-
1.57175648
,
0.22031412
,
1.45862222
,
0.53217483
,
1.39087725
,
1.00170159
,
-
0.87175864
,
-
1.7204628
,
-
1.72008383
,
-
0.38656762
,
-
0.01443311
,
1.46645272
,
-
1.39995027
,
0.22505587
,
-
0.43461126
,
-
0.05511411
,
-
0.79950953
,
-
0.01439556
,
0.08795211
,
1.18943918
,
-
0.84079367
,
-
1.73383629
,
-
0.55662078
,
-
0.30626822
,
-
0.67339015
,
0.44179603
,
0.54316711
,
0.40899998
,
-
0.27831686
,
-
1.11900508
,
-
0.0881724
,
0.35483059
,
2.36277103
,
-
0.04765317
,
-
0.36865309
,
0.73814237
,
1.47151589
,
1.36546791
,
-
0.32649881
,
-
1.0517807
,
2.24768877
,
0.68883753
,
0.58646208
,
-
0.91017133
,
-
0.50462508
,
-
0.4013325
,
-
0.72348958
,
-
0.47368807
,
0.35285577
,
-
1.01817429
,
-
0.5152272
,
0.60321307
,
0.43521205
,
-
0.23733577
,
0.66427642
,
0.82949388
,
0.82443929
,
0.71550399
,
0.34561086
,
0.68570769
,
-
0.40718508
,
-
1.20350206
,
0.15793853
,
-
2.31013632
,
-
0.07934658
,
-
0.09348056
,
0.36576006
,
2.46601582
,
0.11090943
,
0.9144392
,
0.56759721
,
-
0.22112127
,
-
0.21955389
,
0.72474903
,
-
1.28448462
,
1.53285873
,
0.37437943
,
0.31409341
,
1.95433736
,
0.91620457
,
0.86205518
,
1.24365854
,
0.19248386
,
0.22526583
,
0.13462132
,
-
0.27561715
,
-
2.06446075
,
-
0.02306402
,
-
1.38278747
,
1.1411345
,
1.31293464
,
-
1.86041689
,
1.06763375
,
-
0.26541466
,
1.4545635
,
1.11430049
,
-
0.66491818
,
0.87101674
,
0.67768967
,
-
1.02062869
,
-
1.05031872
,
-
2.2764678
,
-
2.0200038
,
0.37592548
,
-
0.26701379
,
-
0.83388507
,
0.19403623
,
1.00968623
,
0.11020003
,
1.16736257
,
-
1.1160326
,
0.47346735
,
0.6126079
,
-
0.19135755
,
1.33624589
,
-
0.29802522
,
-
0.57873946
,
-
1.06555879
,
-
0.20686582
,
1.36892557
,
-
0.19937795
,
0.8649236
,
-
1.40126073
,
1.53441942
,
0.34682792
,
-
1.31724346
,
-
1.32898355
,
2.40126371
,
0.07845283
,
1.35732043
,
-
0.63678312
,
0.39429256
,
-
1.36487007
,
-
0.31026676
,
-
0.44981545
,
-
0.28994772
,
-
0.14657612
,
-
1.75206447
,
-
0.70612341
,
1.20071781
,
-
1.64647579
,
-
0.7133292
,
0.88494766
,
0.52119428
,
-
2.77387547
,
2.07681108
,
-
0.90133125
,
0.2847338
,
0.6174528
,
-
0.20616426
,
-
0.64263535
,
-
1.08496261
,
0.54275119
,
-
0.88503587
,
0.6629802
,
1.47319221
,
-
1.05829155
,
-
0.97027361
,
-
0.93187737
,
-
1.39954746
,
-
0.52359426
,
-
0.14743951
,
1.51522756
,
0.2078452
,
-
1.28156149
,
-
1.19363916
,
-
0.78680223
,
-
0.89094824
,
1.30212069
,
-
0.77974445
,
-
0.58411664
,
0.48764706
,
-
0.67132682
};
std
::
vector
<
float
>
c
=
{
1.33493888
,
1.54562736
,
1.22098756
,
1.33493888
,
1.18358743
,
1.99097753
,
1.00170159
,
1.45862222
,
1.39087725
,
1.46645272
,
1.18943918
,
-
0.01443311
,
1.47151589
,
2.36277103
,
2.24768877
,
0.68883753
,
0.82949388
,
0.71550399
,
1.95433736
,
2.46601582
,
1.53285873
,
1.95433736
,
1.06763375
,
1.4545635
,
1.33624589
,
1.16736257
,
0.6126079
,
1.36892557
,
2.40126371
,
1.53441942
,
0.52119428
,
2.07681108
,
0.88494766
,
1.51522756
,
0.54275119
,
0.6629802
};
migraph
::
shape
a_shape
{
migraph
::
shape
::
float_type
,
{
2
,
3
,
6
,
6
}};
auto
al
=
p
.
add_literal
(
migraph
::
literal
{
a_shape
,
a
});
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
3
,
2
}}},
al
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
cout
<<
result
.
get_shape
()
<<
std
::
endl
;
std
::
vector
<
float
>
results_vector
(
36
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
float
tol
=
1e-6
;
for
(
int
i
=
0
;
i
<
results_vector
.
size
();
i
++
)
{
// std::cout << results_vector[i] << " " << c[i] << std::endl;
EXPECT
(
std
::
abs
(
results_vector
[
i
]
-
c
[
i
])
<
tol
);
}
}
void
softmax_test
()
{
migraph
::
program
p
;
...
...
@@ -564,7 +640,9 @@ int main()
transpose_test
();
contiguous_test
();
softmax_test
();
// maxpool_test();
conv2d_test
();
conv2d_padding_test
();
conv2d_padding_stride_test
();
batch_norm_inference_test
();
}
test/miopen/miopen.cpp
View file @
76f68df4
...
...
@@ -3,7 +3,7 @@
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/miopen/
miopen_
target.hpp>
#include <migraph/miopen/target.hpp>
#include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/hip.hpp>
#include <migraph/manage_ptr.hpp>
...
...
@@ -27,7 +27,7 @@ migraph::argument run_gpu()
{
V
v
;
auto
p
=
v
.
create_program
();
p
.
compile
(
migraph
::
miopen
::
miopen_
target
{});
p
.
compile
(
migraph
::
miopen
::
target
{});
auto
m
=
v
.
create_params
();
for
(
auto
&&
e
:
m
)
...
...
@@ -49,6 +49,23 @@ void verify_program()
visit_all
(
cpu_arg
,
gpu_arg
)([](
auto
cpu
,
auto
gpu
)
{
EXPECT
(
test
::
verify_range
(
cpu
,
gpu
));
});
}
struct
test_literals
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
auto
input
=
p
.
add_literal
(
generate_literal
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}}));
auto
weights
=
p
.
add_literal
(
generate_literal
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}}));
auto
conv
=
p
.
add_instruction
(
migraph
::
convolution
{},
input
,
weights
);
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
conv
);
return
p
;
}
migraph
::
program
::
parameter_map
create_params
()
const
{
return
{};
}
};
struct
test_add
{
migraph
::
program
create_program
()
const
...
...
tools/include/context.hpp
View file @
76f68df4
#ifndef MIGRAPH_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace
migraph
{
<%
...
...
tools/include/operation.hpp
View file @
76f68df4
...
...
@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace
migraph
{
...
...
@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
<%
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
)
%>
...
...
tools/te.py
View file @
76f68df4
...
...
@@ -213,16 +213,21 @@ def internal_name(name):
else
:
return
name
def
generate_call
(
m
,
friend
):
def
generate_call
(
m
,
friend
,
indirect
):
if
m
[
'name'
].
startswith
(
'operator'
):
op
=
m
[
'name'
][
8
:]
args
=
m
[
'args'
]
if
','
in
args
:
return
args
.
replace
(
','
,
op
)
else
:
return
string
.
Template
(
'${op}${arg
a
}'
).
substitute
(
op
=
op
,
args
=
args
)
return
string
.
Template
(
'${op}${arg
s
}'
).
substitute
(
op
=
op
,
args
=
args
)
if
friend
:
return
string
.
Template
(
'${name}(${args})'
).
substitute
(
m
)
if
indirect
:
if
m
[
'args'
]:
return
string
.
Template
(
'${default}(private_detail_te_value, ${args})'
).
substitute
(
m
)
else
:
return
string
.
Template
(
'${default}(private_detail_te_value)'
).
substitute
(
m
)
return
string
.
Template
(
'private_detail_te_value.${name}(${args})'
).
substitute
(
m
)
def
convert_member
(
d
,
struct_name
):
...
...
@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
member_params
=
[]
skip
=
False
friend
=
False
indirect
=
False
if
'friend'
in
d
[
name
]:
friend
=
True
skip
=
True
if
'default'
in
d
[
name
]:
indirect
=
True
for
x
in
d
[
name
]:
t
=
d
[
name
][
x
]
if
x
==
'return'
:
...
...
@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
member
[
'member_const'
]
=
'const'
elif
x
==
'friend'
:
member
[
'friend'
]
=
'friend'
elif
x
==
'default'
:
member
[
'default'
]
=
t
elif
x
==
'using'
:
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
elif
x
.
startswith
(
'__'
)
and
x
.
endswith
(
'__'
):
continue
else
:
use_member
=
not
(
skip
and
struct_name
==
trim_type_name
(
t
))
arg_name
=
x
...
...
@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'params'
]
=
','
.
join
(
params
)
member
[
'member_params'
]
=
','
.
join
(
member_params
)
member
[
'call'
]
=
generate_call
(
member
,
friend
)
member
[
'call'
]
=
generate_call
(
member
,
friend
,
indirect
)
return
member
return
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment