Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
09818ae6
Commit
09818ae6
authored
Jul 07, 2022
by
Paul
Browse files
Merge branch 'develop' into fuse-horiz-contiguous
parents
6545452a
bd503d89
Changes
24
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
394 additions
and
84 deletions
+394
-84
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+4
-0
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+22
-7
src/module.cpp
src/module.cpp
+16
-9
src/program.cpp
src/program.cpp
+7
-5
src/serialize.cpp
src/serialize.cpp
+2
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+1
-1
src/targets/cpu/write_literals.cpp
src/targets/cpu/write_literals.cpp
+2
-0
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+3
-0
src/targets/gpu/deconvolution.cpp
src/targets/gpu/deconvolution.cpp
+70
-29
src/targets/gpu/include/migraphx/gpu/deconvolution.hpp
src/targets/gpu/include/migraphx/gpu/deconvolution.hpp
+4
-4
src/targets/gpu/include/migraphx/gpu/quant_convolution.hpp
src/targets/gpu/include/migraphx/gpu/quant_convolution.hpp
+2
-2
src/targets/gpu/jit/softmax.cpp
src/targets/gpu/jit/softmax.cpp
+107
-0
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+34
-0
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+45
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+2
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+2
-3
src/targets/gpu/quant_convolution.cpp
src/targets/gpu/quant_convolution.cpp
+61
-21
No files found.
src/include/migraphx/module.hpp
View file @
09818ae6
...
@@ -164,6 +164,10 @@ struct module
...
@@ -164,6 +164,10 @@ struct module
instruction_ref
replace_return
(
std
::
vector
<
instruction_ref
>
args
);
instruction_ref
replace_return
(
std
::
vector
<
instruction_ref
>
args
);
instruction_ref
insert_literal
(
instruction_ref
ins
,
literal
l
);
instruction_ref
insert_parameter
(
instruction_ref
ins
,
std
::
string
name
,
shape
s
);
std
::
vector
<
std
::
string
>
get_parameter_names
()
const
;
std
::
vector
<
std
::
string
>
get_parameter_names
()
const
;
shape
get_parameter_shape
(
std
::
string
name
)
const
;
shape
get_parameter_shape
(
std
::
string
name
)
const
;
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
09818ae6
...
@@ -42,11 +42,12 @@ namespace op {
...
@@ -42,11 +42,12 @@ namespace op {
struct
unsqueeze
struct
unsqueeze
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
steps
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
return
pack
(
f
(
self
.
axes
,
"axes"
)
,
f
(
self
.
steps
,
"steps"
)
);
}
}
value
attributes
()
const
value
attributes
()
const
...
@@ -73,6 +74,9 @@ struct unsqueeze
...
@@ -73,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
}
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
...
@@ -80,16 +84,27 @@ struct unsqueeze
...
@@ -80,16 +84,27 @@ struct unsqueeze
std
::
size_t
p
=
0
;
std
::
size_t
p
=
0
;
for
(
auto
i
:
range
(
new_size
))
for
(
auto
i
:
range
(
new_size
))
{
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
axes
.
size
())
{
{
new_lens
[
i
]
=
1
;
std
::
int64_t
step
=
1
;
if
(
p
==
0
)
// unsqueeze on the first axes
if
(
axis_idx
<
steps
.
size
())
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
{
new_strides
[
i
]
=
old_lens
[
0
]
*
old_strides
[
0
];
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
}
}
else
// unsqueeze on middle or last axes
else
{
{
new_strides
[
i
]
=
(
p
<
old_strides
.
size
())
?
old_strides
[
p
-
1
]
:
1
;
if
(
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
1
;
}
}
}
}
else
else
...
...
src/module.cpp
View file @
09818ae6
...
@@ -439,11 +439,7 @@ module::insert_instructions(instruction_ref ins,
...
@@ -439,11 +439,7 @@ module::insert_instructions(instruction_ref ins,
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
r
),
std
::
move
(
map_ins
));
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
r
),
std
::
move
(
map_ins
));
}
}
instruction_ref
module
::
add_literal
(
literal
l
)
instruction_ref
module
::
add_literal
(
literal
l
)
{
return
insert_literal
(
begin
(),
std
::
move
(
l
));
}
{
impl
->
emplace_front
(
std
::
move
(
l
));
return
impl
->
instructions
.
begin
();
}
instruction_ref
module
::
add_outline
(
const
shape
&
s
)
instruction_ref
module
::
add_outline
(
const
shape
&
s
)
{
{
...
@@ -453,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)
...
@@ -453,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)
instruction_ref
module
::
add_parameter
(
std
::
string
name
,
shape
s
)
instruction_ref
module
::
add_parameter
(
std
::
string
name
,
shape
s
)
{
{
assert
(
get_parameter_shape
(
name
)
==
shape
{});
return
insert_parameter
(
begin
(),
std
::
move
(
name
),
std
::
move
(
s
));
impl
->
push_front
({
builtin
::
param
{
std
::
move
(
name
),
impl
->
nparams
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
return
impl
->
instructions
.
begin
();
}
}
instruction_ref
module
::
add_return
(
std
::
vector
<
instruction_ref
>
args
)
instruction_ref
module
::
add_return
(
std
::
vector
<
instruction_ref
>
args
)
...
@@ -469,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
...
@@ -469,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return
result
;
return
result
;
}
}
instruction_ref
module
::
insert_literal
(
instruction_ref
ins
,
literal
l
)
{
impl
->
emplace
(
ins
,
std
::
move
(
l
));
return
std
::
prev
(
ins
);
}
instruction_ref
module
::
insert_parameter
(
instruction_ref
ins
,
std
::
string
name
,
shape
s
)
{
assert
(
get_parameter_shape
(
name
)
==
shape
{});
impl
->
insert
(
ins
,
{
builtin
::
param
{
std
::
move
(
name
),
impl
->
nparams
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
return
std
::
prev
(
ins
);
}
instruction_ref
module
::
replace_return
(
std
::
vector
<
instruction_ref
>
args
)
instruction_ref
module
::
replace_return
(
std
::
vector
<
instruction_ref
>
args
)
{
{
auto
last
=
std
::
prev
(
this
->
end
());
auto
last
=
std
::
prev
(
this
->
end
());
...
...
src/program.cpp
View file @
09818ae6
...
@@ -504,12 +504,14 @@ static void mod_from_val(module_ref mod,
...
@@ -504,12 +504,14 @@ static void mod_from_val(module_ref mod,
if
(
name
==
"@param"
)
if
(
name
==
"@param"
)
{
{
output
=
mod
->
add_parameter
(
fields
[
"parameter"
].
to
<
std
::
string
>
(),
output
=
mod
->
insert_parameter
(
mod
->
end
(),
fields
[
"parameter"
].
to
<
std
::
string
>
(),
migraphx
::
from_value
<
shape
>
(
node
.
at
(
"shape"
)));
migraphx
::
from_value
<
shape
>
(
node
.
at
(
"shape"
)));
}
}
else
if
(
name
==
"@literal"
)
else
if
(
name
==
"@literal"
)
{
{
output
=
mod
->
add_literal
(
migraphx
::
from_value
<
literal
>
(
node
.
at
(
"literal"
)));
output
=
mod
->
insert_literal
(
mod
->
end
(),
migraphx
::
from_value
<
literal
>
(
node
.
at
(
"literal"
)));
}
}
else
else
{
{
...
@@ -544,11 +546,11 @@ static void mod_from_val(module_ref mod,
...
@@ -544,11 +546,11 @@ static void mod_from_val(module_ref mod,
}
}
else
if
(
module_inputs
.
empty
())
else
if
(
module_inputs
.
empty
())
{
{
output
=
mod
->
add
_instruction
(
op
,
inputs
);
output
=
mod
->
insert
_instruction
(
mod
->
end
(),
op
,
inputs
);
}
}
else
else
{
{
output
=
mod
->
add
_instruction
(
op
,
inputs
,
module_inputs
);
output
=
mod
->
insert
_instruction
(
mod
->
end
(),
op
,
inputs
,
module_inputs
);
}
}
}
}
output
->
set_normalized
(
normalized
);
output
->
set_normalized
(
normalized
);
...
...
src/serialize.cpp
View file @
09818ae6
...
@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd)
...
@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd)
result
[
"shape"
]
=
migraphx
::
to_value
(
rd
.
get_shape
());
result
[
"shape"
]
=
migraphx
::
to_value
(
rd
.
get_shape
());
if
(
rd
.
get_shape
().
type
()
==
shape
::
tuple_type
)
if
(
rd
.
get_shape
().
type
()
==
shape
::
tuple_type
)
result
[
"sub"
]
=
migraphx
::
to_value
(
rd
.
get_sub_objects
());
result
[
"sub"
]
=
migraphx
::
to_value
(
rd
.
get_sub_objects
());
else
else
if
(
not
rd
.
empty
())
result
[
"data"
]
=
migraphx
::
value
::
binary
(
rd
.
data
(),
rd
.
get_shape
().
bytes
());
result
[
"data"
]
=
migraphx
::
value
::
binary
(
rd
.
data
(),
rd
.
get_shape
().
bytes
());
v
=
result
;
v
=
result
;
}
}
...
@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a)
...
@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a)
literal
l
=
migraphx
::
from_value
<
literal
>
(
v
);
literal
l
=
migraphx
::
from_value
<
literal
>
(
v
);
a
=
l
.
get_argument
();
a
=
l
.
get_argument
();
}
}
else
else
if
(
v
.
contains
(
"sub"
))
{
{
a
=
migraphx
::
from_value
<
std
::
vector
<
argument
>>
(
v
.
at
(
"sub"
));
a
=
migraphx
::
from_value
<
std
::
vector
<
argument
>>
(
v
.
at
(
"sub"
));
}
}
...
...
src/simplify_reshapes.cpp
View file @
09818ae6
...
@@ -272,7 +272,7 @@ struct find_concat_transpose
...
@@ -272,7 +272,7 @@ struct find_concat_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
transpose
_shape
(
)));
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
name
(
"
transpose
"
)));
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
...
...
src/targets/cpu/write_literals.cpp
View file @
09818ae6
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include <migraphx/module.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/register_op.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -52,6 +53,7 @@ struct cpu_literal
...
@@ -52,6 +53,7 @@ struct cpu_literal
return
os
;
return
os
;
}
}
};
};
MIGRAPHX_REGISTER_OP
(
cpu_literal
);
void
write_literals
::
apply
(
module
&
m
)
const
void
write_literals
::
apply
(
module
&
m
)
const
{
{
...
...
src/targets/gpu/compile_gen.cpp
View file @
09818ae6
...
@@ -43,6 +43,9 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
...
@@ -43,6 +43,9 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
vectorize
vectorize
::
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
vectorize
vectorize
::
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
{
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
auto
&
s
)
{
return
s
.
lens
()[
axis
]
==
1
;
}))
return
{
1
,
axis
};
auto
sizes
=
vector_sizes
(
inputs
);
auto
sizes
=
vector_sizes
(
inputs
);
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
transform
(
inputs
.
begin
(),
std
::
transform
(
inputs
.
begin
(),
...
...
src/targets/gpu/deconvolution.cpp
View file @
09818ae6
...
@@ -59,31 +59,30 @@ argument miopen_deconvolution::compute(context& ctx,
...
@@ -59,31 +59,30 @@ argument miopen_deconvolution::compute(context& ctx,
auto
w_desc
=
make_tensor
(
reshape_if_1d
(
args
[
1
].
get_shape
()));
auto
w_desc
=
make_tensor
(
reshape_if_1d
(
args
[
1
].
get_shape
()));
auto
y_desc
=
make_tensor
(
reshape_if_1d
(
output_shape
));
auto
y_desc
=
make_tensor
(
reshape_if_1d
(
output_shape
));
float
alpha
=
1
;
if
(
solution_id
==
0
)
float
beta
=
0
;
MIGRAPHX_THROW
(
"MIOpen Deconvolution: invalid solution ID"
);
auto
status
=
miopenConvolutionForward
(
ctx
.
get_stream
().
get_miopen
(),
&
alpha
,
auto
status
=
miopenConvolutionForwardImmediate
(
ctx
.
get_stream
().
get_miopen
(),
x_desc
.
get
(),
args
[
0
].
implicit
(),
w_desc
.
get
(),
w_desc
.
get
(),
args
[
1
].
implicit
(),
args
[
1
].
implicit
(),
x_desc
.
get
(),
args
[
0
].
implicit
(),
cd
.
get
(),
cd
.
get
(),
algo
,
&
beta
,
y_desc
.
get
(),
y_desc
.
get
(),
args
[
3
].
implicit
(),
args
[
3
].
implicit
(),
args
[
2
].
implicit
(),
args
[
2
].
implicit
(),
args
[
2
].
get_shape
().
bytes
());
args
[
2
].
get_shape
().
bytes
(),
solution_id
);
if
(
status
!=
miopenStatusSuccess
)
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"
R
unning
de
convolution failed"
);
MIGRAPHX_THROW
(
"
MIOpen Deconvolution: r
unning convolution failed"
);
return
args
[
3
];
return
args
[
3
];
}
}
shape
miopen_deconvolution
::
compile
(
context
&
ctx
,
shape
miopen_deconvolution
::
find
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
)
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
)
{
{
shape
workspace_shape
{};
shape
workspace_shape
{};
auto
x_desc
=
make_tensor
(
reshape_if_1d
(
inputs
[
0
]));
auto
x_desc
=
make_tensor
(
reshape_if_1d
(
inputs
[
0
]));
auto
w_desc
=
make_tensor
(
reshape_if_1d
(
inputs
[
1
]));
auto
w_desc
=
make_tensor
(
reshape_if_1d
(
inputs
[
1
]));
auto
y_desc
=
make_tensor
(
reshape_if_1d
(
output_shape
));
auto
y_desc
=
make_tensor
(
reshape_if_1d
(
output_shape
));
...
@@ -119,9 +118,35 @@ shape miopen_deconvolution::compile(context& ctx,
...
@@ -119,9 +118,35 @@ shape miopen_deconvolution::compile(context& ctx,
workspace_size
,
workspace_size
,
false
);
false
);
if
(
status
!=
miopenStatusSuccess
)
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"Find deconvolution failed"
);
MIGRAPHX_THROW
(
"MIOpen Deconvolution: find convolution failed"
);
handle
=
ctx
.
get_stream
().
get_miopen
();
algo
=
perf
.
fwd_algo
;
algo
=
perf
.
fwd_algo
;
size_t
solution_count
;
status
=
miopenConvolutionForwardGetSolutionCount
(
ctx
.
get_stream
().
get_miopen
(),
w_desc
.
get
(),
x_desc
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
&
solution_count
);
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen Deconvolution: get solution count failed"
);
std
::
vector
<
miopenConvSolution_t
>
solutions
(
solution_count
);
status
=
miopenConvolutionForwardGetSolution
(
ctx
.
get_stream
().
get_miopen
(),
w_desc
.
get
(),
x_desc
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
solution_count
,
&
solution_count
,
solutions
.
data
());
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen Deconvolution: get solution failed"
);
solution_id
=
solutions
.
front
().
solution_id
;
return
shape
{
shape
::
int8_type
,
{
perf
.
memory
}};
return
shape
{
shape
::
int8_type
,
{
perf
.
memory
}};
}
}
...
@@ -129,13 +154,29 @@ void miopen_deconvolution::finalize(context& ctx,
...
@@ -129,13 +154,29 @@ void miopen_deconvolution::finalize(context& ctx,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
)
std
::
vector
<
shape
>
inputs
)
{
{
if
(
handle
==
ctx
.
get_stream
().
get_miopen
())
if
(
cd
==
nullptr
)
return
;
cd
=
make_deconv
(
op
);
if
(
solution_id
==
0
)
{
// Check that workspace hasn't changed
// Check that workspace hasn't changed
auto
size
=
inputs
.
at
(
2
).
bytes
();
auto
size
=
inputs
.
at
(
2
).
bytes
();
auto
ws
=
compile
(
ctx
,
output_shape
,
std
::
move
(
inputs
)
)
;
auto
ws
=
find
(
ctx
,
output_shape
,
inputs
);
if
(
ws
.
bytes
()
>
size
)
if
(
ws
.
bytes
()
>
size
)
MIGRAPHX_THROW
(
"Workspace has changed during finalization."
);
MIGRAPHX_THROW
(
"MIOpen Deconvolution: workspace has changed during finalization."
);
}
auto
x_desc
=
make_tensor
(
reshape_if_1d
(
inputs
[
0
]));
auto
w_desc
=
make_tensor
(
reshape_if_1d
(
inputs
[
1
]));
auto
y_desc
=
make_tensor
(
reshape_if_1d
(
output_shape
));
auto
status
=
miopenConvolutionForwardCompileSolution
(
ctx
.
get_stream
().
get_miopen
(),
w_desc
.
get
(),
x_desc
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
solution_id
);
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen Deconvolution: compile solution failed"
);
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/deconvolution.hpp
View file @
09818ae6
...
@@ -39,20 +39,20 @@ struct miopen_deconvolution
...
@@ -39,20 +39,20 @@ struct miopen_deconvolution
op
::
deconvolution
op
;
op
::
deconvolution
op
;
shared
<
convolution_descriptor
>
cd
;
shared
<
convolution_descriptor
>
cd
;
miopenConvFwdAlgorithm_t
algo
{};
miopenConvFwdAlgorithm_t
algo
{};
miopenHandle_t
handle
=
nullptr
;
uint64_t
solution_id
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
// TODO: Add algo
return
pack_join
(
op
::
deconvolution
::
reflect
(
self
.
op
,
f
),
return
op
::
convolution
::
reflect
(
self
.
op
,
f
);
pack
(
f
(
self
.
solution_id
,
"solution_id"
))
);
}
}
std
::
string
name
()
const
{
return
"gpu::deconv"
;
}
std
::
string
name
()
const
{
return
"gpu::deconv"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
argument
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
shape
compile
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
);
shape
find
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
);
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
);
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
);
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
{
...
...
src/targets/gpu/include/migraphx/gpu/quant_convolution.hpp
View file @
09818ae6
...
@@ -41,7 +41,7 @@ struct miopen_quant_convolution
...
@@ -41,7 +41,7 @@ struct miopen_quant_convolution
bool
int8_x4_format
=
false
;
bool
int8_x4_format
=
false
;
shared
<
convolution_descriptor
>
cd
;
shared
<
convolution_descriptor
>
cd
;
miopenConvFwdAlgorithm_t
algo
{};
miopenConvFwdAlgorithm_t
algo
{};
miopenHandle_t
handle
=
nullptr
;
uint64_t
solution_id
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -55,7 +55,7 @@ struct miopen_quant_convolution
...
@@ -55,7 +55,7 @@ struct miopen_quant_convolution
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
argument
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
shape
compile
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
);
shape
find
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
);
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
);
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
);
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
{
...
...
src/targets/gpu/jit/softmax.cpp
0 → 100644
View file @
09818ae6
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
softmax_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
softmax_compiler
:
compiler
<
softmax_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"softmax"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
// TODO: Use reduce_dims
auto
axis
=
v
.
at
(
"axis"
).
to
<
int64_t
>
();
auto
faxis
=
find_fast_axis
({
inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
faxis
==
axis
)
{
vec
=
vectorize
::
elements
(
faxis
,
inputs
);
}
auto
relements
=
inputs
[
0
].
lens
()[
axis
]
/
vec
.
size
;
auto
nelements
=
(
inputs
.
back
().
elements
()
/
inputs
[
0
].
lens
()[
axis
]);
auto
block_size
=
compute_block_size
(
relements
,
256
);
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"softmax_kernel"
;
auto
src
=
interpolate_string
(
softmax_kernel
,
{{
"transformers"
,
make_transformer_args
(
vec
)},
{
"axis"
,
to_string
(
axis
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
09818ae6
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -213,6 +214,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
...
@@ -213,6 +214,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
return
integral_const_array
<
T
,
f
(
Xs
)...
>
{};
return
integral_const_array
<
T
,
f
(
Xs
)...
>
{};
}
}
template
<
class
T
,
T
...
Xs
,
class
F
>
constexpr
auto
transform_i
(
integral_const_array
<
T
,
Xs
...
>
,
F
f
)
{
return
sequence_c
<
sizeof
...(
Xs
)
>
(
[
=
](
auto
...
is
)
{
return
integral_const_array
<
T
,
f
(
Xs
,
is
)...
>
{};
});
}
template
<
class
T
,
T
...
Xs
,
class
U
,
U
...
Ys
,
class
F
>
template
<
class
T
,
T
...
Xs
,
class
U
,
U
...
Ys
,
class
F
>
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
integral_const_array
<
U
,
Ys
...
>
,
F
f
)
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
integral_const_array
<
U
,
Ys
...
>
,
F
f
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
09818ae6
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/
array
.hpp>
#include <migraphx/kernels/
integral_constant
.hpp>
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
#define MIGRAPHX_RETURNS(...) \
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
09818ae6
...
@@ -175,6 +175,21 @@ constexpr auto sliced(Slicer slicer, F f)
...
@@ -175,6 +175,21 @@ constexpr auto sliced(Slicer slicer, F f)
};
};
}
}
template
<
class
Input
,
index_int
Axis
>
constexpr
auto
compute_reduce_axis
()
{
constexpr
auto
lens
=
transform_i
(
get_shape_c
<
Input
>
{}.
lens
,
[](
index_int
x
,
index_int
i
)
->
index_int
{
if
(
i
==
Axis
)
return
1
;
return
x
;
});
return
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
}
template
<
class
Input
,
index_int
Axis
>
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
struct
block
struct
block
{
{
template
<
class
Slicer
>
template
<
class
Slicer
>
...
@@ -201,6 +216,14 @@ struct block
...
@@ -201,6 +216,14 @@ struct block
if
(
idx
.
local
==
0
)
if
(
idx
.
local
==
0
)
f
();
f
();
}
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
});
}
};
};
template
<
class
Slicer
>
template
<
class
Slicer
>
...
@@ -247,6 +270,17 @@ struct lane
...
@@ -247,6 +270,17 @@ struct lane
{
{
f
();
f
();
}
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
f
(
x
[
j
],
xs
[
j
]...);
}
});
}
};
};
template
<
class
Slicer
>
template
<
class
Slicer
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
View file @
09818ae6
...
@@ -32,6 +32,7 @@ namespace migraphx {
...
@@ -32,6 +32,7 @@ namespace migraphx {
template
<
class
Lens
,
class
Strides
>
template
<
class
Lens
,
class
Strides
>
struct
shape
struct
shape
{
{
using
shape_type
=
shape
;
using
index_array
=
typename
Lens
::
base_array
;
using
index_array
=
typename
Lens
::
base_array
;
Lens
lens
=
{};
Lens
lens
=
{};
Strides
strides
=
{};
Strides
strides
=
{};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
0 → 100644
View file @
09818ae6
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
namespace
migraphx
{
template
<
index_int
Axis
,
class
Input
,
class
Output
>
__device__
void
softmax
(
Input
input
,
Output
output
)
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
auto
batch_max
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
batch_max
);
})(
input
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
migraphx
::
exp
(
x
-
batch_max
)
/
batch_sum
;
})(
output
,
input
);
});
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
09818ae6
...
@@ -27,6 +27,8 @@
...
@@ -27,6 +27,8 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
...
src/targets/gpu/lowering.cpp
View file @
09818ae6
...
@@ -186,7 +186,6 @@ struct miopen_apply
...
@@ -186,7 +186,6 @@ struct miopen_apply
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"scatter_none"
);
add_extend_op
(
"scatter_none"
);
add_extend_op
(
"softmax"
);
add_extend_op
(
"topk"
);
add_extend_op
(
"topk"
);
add_batch_norm_inference_op
();
add_batch_norm_inference_op
();
...
@@ -301,7 +300,7 @@ struct miopen_apply
...
@@ -301,7 +300,7 @@ struct miopen_apply
auto
&&
op
=
any_cast
<
op
::
deconvolution
>
(
ins
->
get_operator
());
auto
&&
op
=
any_cast
<
op
::
deconvolution
>
(
ins
->
get_operator
());
auto
conv
=
miopen_deconvolution
{
op
,
make_deconv
(
op
)};
auto
conv
=
miopen_deconvolution
{
op
,
make_deconv
(
op
)};
auto
ws
=
conv
.
compile
(
get_context
(),
ins
->
get_shape
(),
to_shapes
(
ins
->
inputs
()));
auto
ws
=
conv
.
find
(
get_context
(),
ins
->
get_shape
(),
to_shapes
(
ins
->
inputs
()));
auto
workspace
=
insert_allocation
(
ins
,
ws
);
auto
workspace
=
insert_allocation
(
ins
,
ws
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
...
@@ -332,7 +331,7 @@ struct miopen_apply
...
@@ -332,7 +331,7 @@ struct miopen_apply
miopen_quant_convolution
conv
;
miopen_quant_convolution
conv
;
auto
compile_quant_conv_with_format
=
[
&
](
bool
format
)
{
auto
compile_quant_conv_with_format
=
[
&
](
bool
format
)
{
conv
=
miopen_quant_convolution
{
op
,
format
,
make_conv
(
op
)};
conv
=
miopen_quant_convolution
{
op
,
format
,
make_conv
(
op
)};
ws
=
conv
.
compile
(
get_context
(),
ins
->
get_shape
(),
to_shapes
(
ins
->
inputs
()));
ws
=
conv
.
find
(
get_context
(),
ins
->
get_shape
(),
to_shapes
(
ins
->
inputs
()));
};
};
try
try
...
...
src/targets/gpu/quant_convolution.cpp
View file @
09818ae6
...
@@ -67,7 +67,7 @@ argument miopen_quant_convolution::compute(context& ctx,
...
@@ -67,7 +67,7 @@ argument miopen_quant_convolution::compute(context& ctx,
return
args
[
3
];
return
args
[
3
];
}
}
shape
miopen_quant_convolution
::
compile
(
context
&
ctx
,
shape
miopen_quant_convolution
::
find
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
)
std
::
vector
<
shape
>
inputs
)
{
{
...
@@ -92,8 +92,8 @@ shape miopen_quant_convolution::compile(context& ctx,
...
@@ -92,8 +92,8 @@ shape miopen_quant_convolution::compile(context& ctx,
x_shape
=
pack_int8_shape
(
x_shape
);
x_shape
=
pack_int8_shape
(
x_shape
);
w_shape
=
pack_int8_shape
(
w_shape
);
w_shape
=
pack_int8_shape
(
w_shape
);
}
}
auto
arg_vec4_x
=
to_gpu
(
generate_argument
(
x_shape
));
auto
x
=
to_gpu
(
generate_argument
(
x_shape
));
auto
arg_vec4_w
=
to_gpu
(
generate_argument
(
w_shape
));
auto
w
=
to_gpu
(
generate_argument
(
w_shape
));
auto
y
=
allocate_gpu
(
output_shape
);
auto
y
=
allocate_gpu
(
output_shape
);
auto
workspace
=
allocate_gpu
(
workspace_shape
);
auto
workspace
=
allocate_gpu
(
workspace_shape
);
...
@@ -101,9 +101,9 @@ shape miopen_quant_convolution::compile(context& ctx,
...
@@ -101,9 +101,9 @@ shape miopen_quant_convolution::compile(context& ctx,
miopenConvAlgoPerf_t
perf
;
miopenConvAlgoPerf_t
perf
;
auto
status
=
miopenFindConvolutionForwardAlgorithm
(
ctx
.
get_stream
().
get_miopen
(),
auto
status
=
miopenFindConvolutionForwardAlgorithm
(
ctx
.
get_stream
().
get_miopen
(),
x_desc
.
get
(),
x_desc
.
get
(),
arg_vec4_
x
.
implicit
(),
x
.
implicit
(),
w_desc
.
get
(),
w_desc
.
get
(),
arg_vec4_
w
.
implicit
(),
w
.
implicit
(),
cd
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
y_desc
.
get
(),
y
.
implicit
(),
y
.
implicit
(),
...
@@ -114,11 +114,35 @@ shape miopen_quant_convolution::compile(context& ctx,
...
@@ -114,11 +114,35 @@ shape miopen_quant_convolution::compile(context& ctx,
workspace_size
,
workspace_size
,
false
);
false
);
if
(
status
!=
miopenStatusSuccess
)
if
(
status
!=
miopenStatusSuccess
)
{
MIGRAPHX_THROW
(
"MIOpen Quant Convolution: find convolution failed"
);
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: find convolution failed"
);
}
handle
=
ctx
.
get_stream
().
get_miopen
();
algo
=
perf
.
fwd_algo
;
algo
=
perf
.
fwd_algo
;
size_t
solution_count
;
status
=
miopenConvolutionForwardGetSolutionCount
(
ctx
.
get_stream
().
get_miopen
(),
w_desc
.
get
(),
x_desc
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
&
solution_count
);
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen Quant Convolution: get solution count failed"
);
std
::
vector
<
miopenConvSolution_t
>
solutions
(
solution_count
);
status
=
miopenConvolutionForwardGetSolution
(
ctx
.
get_stream
().
get_miopen
(),
w_desc
.
get
(),
x_desc
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
solution_count
,
&
solution_count
,
solutions
.
data
());
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen Quant Convolution: get solution failed"
);
solution_id
=
solutions
.
front
().
solution_id
;
return
shape
{
shape
::
int8_type
,
{
perf
.
memory
}};
return
shape
{
shape
::
int8_type
,
{
perf
.
memory
}};
}
}
...
@@ -126,13 +150,29 @@ void miopen_quant_convolution::finalize(context& ctx,
...
@@ -126,13 +150,29 @@ void miopen_quant_convolution::finalize(context& ctx,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
)
std
::
vector
<
shape
>
inputs
)
{
{
if
(
handle
==
ctx
.
get_stream
().
get_miopen
())
if
(
cd
==
nullptr
)
return
;
cd
=
make_conv
(
op
);
if
(
solution_id
==
0
)
{
// Check that workspace hasn't changed
// Check that workspace hasn't changed
auto
size
=
inputs
.
at
(
2
).
bytes
();
auto
size
=
inputs
.
at
(
2
).
bytes
();
auto
ws
=
compile
(
ctx
,
output_shape
,
std
::
move
(
inputs
)
)
;
auto
ws
=
find
(
ctx
,
output_shape
,
inputs
);
if
(
ws
.
bytes
()
>
size
)
if
(
ws
.
bytes
()
>
size
)
MIGRAPHX_THROW
(
"Workspace has changed during finalization."
);
MIGRAPHX_THROW
(
"MIOpen Quant Convolution: workspace has changed during finalization."
);
}
auto
x_desc
=
make_tensor
(
inputs
[
0
],
int8_x4_format
);
auto
w_desc
=
make_tensor
(
inputs
[
1
],
int8_x4_format
);
auto
y_desc
=
make_tensor
(
output_shape
);
auto
status
=
miopenConvolutionForwardCompileSolution
(
ctx
.
get_stream
().
get_miopen
(),
w_desc
.
get
(),
x_desc
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
solution_id
);
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen Quant Convolution: compile solution failed"
);
}
}
shape
miopen_quant_convolution
::
pack_int8_shape
(
const
shape
&
s
)
const
shape
miopen_quant_convolution
::
pack_int8_shape
(
const
shape
&
s
)
const
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment