Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a92cde80
Commit
a92cde80
authored
Dec 14, 2022
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/jit-reduce-reg' into ck-gsg
parents
fabbb929
dae94657
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
367 additions
and
49 deletions
+367
-49
src/driver/main.cpp
src/driver/main.cpp
+8
-2
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+6
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-0
src/module.cpp
src/module.cpp
+88
-0
src/program.cpp
src/program.cpp
+19
-0
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+48
-7
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+7
-9
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+159
-25
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+7
-6
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+24
-0
No files found.
src/driver/main.cpp
View file @
a92cde80
...
@@ -109,8 +109,12 @@ struct loader
...
@@ -109,8 +109,12 @@ struct loader
ap
(
brief
,
{
"--brief"
},
ap
.
help
(
"Make the output brief."
),
ap
.
set_value
(
true
));
ap
(
brief
,
{
"--brief"
},
ap
.
help
(
"Make the output brief."
),
ap
.
set_value
(
true
));
ap
(
output_type
,
ap
(
output_type
,
{
"--cpp"
},
{
"--cpp"
},
ap
.
help
(
"Print out the program as
cpp
program."
),
ap
.
help
(
"Print out the program as
C++
program."
),
ap
.
set_value
(
"cpp"
));
ap
.
set_value
(
"cpp"
));
ap
(
output_type
,
{
"--python"
,
"--py"
},
ap
.
help
(
"Print out the program as python program."
),
ap
.
set_value
(
"py"
));
ap
(
output_type
,
{
"--json"
},
ap
.
help
(
"Print out program as json."
),
ap
.
set_value
(
"json"
));
ap
(
output_type
,
{
"--json"
},
ap
.
help
(
"Print out program as json."
),
ap
.
set_value
(
"json"
));
ap
(
output_type
,
ap
(
output_type
,
{
"--text"
},
{
"--text"
},
...
@@ -259,7 +263,9 @@ struct loader
...
@@ -259,7 +263,9 @@ struct loader
type
=
"binary"
;
type
=
"binary"
;
}
}
if
(
type
==
"cpp"
)
if
(
type
==
"py"
)
p
.
print_py
(
*
os
);
else
if
(
type
==
"cpp"
)
p
.
print_cpp
(
*
os
);
p
.
print_cpp
(
*
os
);
else
if
(
type
==
"graphviz"
)
else
if
(
type
==
"graphviz"
)
p
.
print_graph
(
*
os
,
brief
);
p
.
print_graph
(
*
os
,
brief
);
...
...
src/include/migraphx/module.hpp
View file @
a92cde80
...
@@ -205,6 +205,12 @@ struct module
...
@@ -205,6 +205,12 @@ struct module
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
print_cpp
(
std
::
ostream
&
os
,
print_cpp
(
std
::
ostream
&
os
,
...
...
src/include/migraphx/program.hpp
View file @
a92cde80
...
@@ -115,6 +115,7 @@ struct program
...
@@ -115,6 +115,7 @@ struct program
print_func
)
const
;
print_func
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
...
...
src/module.cpp
View file @
a92cde80
...
@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
...
@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
}
}
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
auto
v
=
op
.
to_value
();
os
<<
"migraphx.op("
<<
enclose_name
(
op
.
name
());
auto
default_values
=
make_op
(
op
.
name
()).
to_value
();
for
(
auto
&&
x
:
v
)
{
auto
name
=
x
.
get_key
();
if
(
default_values
[
name
]
==
x
)
continue
;
os
<<
", "
<<
name
<<
"="
<<
to_json_string
(
x
.
without_key
());
}
os
<<
")"
;
}
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
...
@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
...
@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
os
<<
")"
;
os
<<
")"
;
}
}
static
void
print_py_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
os
<<
"migraphx.shape("
<<
s
.
type_string
()
<<
", lens="
<<
to_json_string
(
s
.
lens
());
if
(
not
s
.
standard
())
os
<<
", strides="
<<
to_json_string
(
s
.
strides
());
os
<<
")"
;
}
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
{
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
...
@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
...
@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os
<<
"}"
;
os
<<
"}"
;
}
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
{
// cppcheck-suppress variableScope
unsigned
long
seed
=
names
.
size
();
auto
last
=
std
::
prev
(
this
->
end
());
names
=
this
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
std
::
vector
<
std
::
string
>
input_vars
;
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
input_vars
),
[
&
](
auto
input
)
{
return
cpp_var_name
(
ins_names
.
at
(
input
));
});
if
(
ins
!=
last
)
os
<<
cpp_var_name
(
ins_names
.
at
(
ins
))
<<
" = "
;
if
(
ins
->
name
()
==
"@literal"
)
{
os
<<
mname
<<
".add_literal("
;
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
});
// Disable abs for now
use_abs
=
false
;
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_literal("
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
", "
<<
seed
<<
")"
;
if
(
use_abs
)
os
<<
")"
;
os
<<
")"
<<
std
::
endl
;
seed
++
;
}
else
if
(
ins
->
name
()
==
"@param"
)
{
std
::
string
name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
os
<<
mname
<<
".add_parameter("
<<
enclose_name
(
name
)
<<
","
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
")"
<<
std
::
endl
;
}
else
if
(
ins
->
name
()
==
"@return"
)
{
os
<<
mname
<<
".add_return(["
<<
join_strings
(
input_vars
,
", "
)
<<
"])"
<<
std
::
endl
;
}
else
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
os
<<
mname
<<
".add_instruction("
;
print_py_op
(
os
,
ins
->
get_operator
());
os
<<
", ["
<<
join_strings
(
input_vars
,
", "
)
<<
"]"
;
os
<<
")"
<<
std
::
endl
;
}
},
names
);
return
names
;
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_cpp
(
std
::
ostream
&
os
,
module
::
print_cpp
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
const
std
::
string
&
mname
,
...
@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
...
@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
return
names
;
return
names
;
}
}
void
module
::
print_py
(
std
::
ostream
&
os
)
const
{
this
->
print_py
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
...
...
src/program.cpp
View file @
a92cde80
...
@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
...
@@ -854,6 +854,25 @@ void program::print_graph(std::ostream& os, bool brief) const
mm
->
print_graph
(
os
,
brief
);
mm
->
print_graph
(
os
,
brief
);
}
}
void
program
::
print_py
(
std
::
ostream
&
os
)
const
{
auto
vec_modules
=
this
->
get_modules
();
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
;
os
<<
"p = migraphx.program()
\n
"
;
for
(
auto
&
mod
:
vec_modules
)
{
std
::
string
var_name
=
"m"
+
mod
->
name
();
os
<<
var_name
<<
" = "
;
if
(
mod
->
name
()
==
"main"
)
os
<<
"p.get_main_module()"
;
else
os
<<
"p.create_module(
\"
"
<<
mod
->
name
()
<<
"
\"
);"
;
os
<<
std
::
endl
;
names
=
mod
->
print_py
(
os
,
var_name
,
names
);
os
<<
std
::
endl
;
}
}
void
program
::
print_cpp
(
std
::
ostream
&
os
)
const
void
program
::
print_cpp
(
std
::
ostream
&
os
)
const
{
{
auto
vec_modules
=
this
->
get_modules
();
auto
vec_modules
=
this
->
get_modules
();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
a92cde80
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/functional.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -137,28 +138,68 @@ struct index
...
@@ -137,28 +138,68 @@ struct index
return
(
n
-
_c
<
1
>
)
/
stride
+
_c
<
1
>
;
return
(
n
-
_c
<
1
>
)
/
stride
+
_c
<
1
>
;
}
}
template
<
class
N
>
constexpr
auto
max_global_stride_iterations
(
N
n
)
const
{
return
max_stride_iterations
(
n
,
nglobal
());
}
template
<
class
N
>
constexpr
auto
max_local_stride_iterations
(
N
n
)
const
{
return
max_stride_iterations
(
n
,
nlocal
());
}
template
<
class
F
,
class
I
,
class
D
>
static
constexpr
auto
invoke_loop
(
F
f
,
I
i
,
D
d
)
->
decltype
(
f
(
i
,
d
))
{
return
f
(
i
,
d
);
}
template
<
class
F
,
class
I
,
class
D
>
static
constexpr
auto
invoke_loop
(
F
f
,
I
i
,
D
)
->
decltype
(
f
(
i
))
{
return
f
(
i
);
}
template
<
class
F
,
class
N
,
class
Stride
>
template
<
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
static
constexpr
void
for_stride
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
{
MIGRAPHX_ASSERT
(
start
<
stride
);
MIGRAPHX_ASSERT
(
start
<
stride
);
if
constexpr
(
not
is_integral
<
N
>
{}
and
not
is_integral
<
Stride
>
{}
and
if
constexpr
(
not
is_integral
<
N
>
{}
and
not
is_integral
<
Stride
>
{})
max_stride_iterations
(
n
,
stride
)
==
1
)
{
{
if
constexpr
(
stride
>
n
)
if
constexpr
(
max_
stride
_iterations
(
n
,
stride
)
==
1
)
{
{
if
(
start
<
n
)
if
constexpr
(
stride
>
n
)
f
(
start
);
{
if
(
start
<
n
)
invoke_loop
(
f
,
start
,
_c
<
0
>
);
}
else
{
invoke_loop
(
f
,
start
,
_c
<
0
>
);
}
}
}
else
else
{
{
f
(
start
);
static_assert
(
max_stride_iterations
(
n
,
stride
)
<
64
);
sequence
(
max_stride_iterations
(
n
,
stride
),
[
&
](
auto
...
ks
)
{
fold
([
&
](
auto
d
,
auto
k
)
{
auto
i
=
start
+
stride
*
k
;
if
(
i
<
n
)
invoke_loop
(
f
,
i
,
d
);
return
d
+
_c
<
1
>
;
})(
_c
<
0
>
,
ks
...);
});
}
}
}
}
else
else
{
{
index_int
k
=
0
;
for
(
index_int
i
=
start
;
i
<
n
;
i
+=
stride
)
for
(
index_int
i
=
start
;
i
<
n
;
i
+=
stride
)
{
{
f
(
i
);
invoke_loop
(
f
,
i
,
k
);
k
++
;
}
}
}
}
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
a92cde80
...
@@ -48,26 +48,24 @@ __device__ void generic_binary_layernorm(
...
@@ -48,26 +48,24 @@ __device__ void generic_binary_layernorm(
{
{
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
using
value_type
=
typename
Input1
::
type
;
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
using
value_type
=
typename
Input1
::
type
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
auto
means
=
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_type
<
value_type
>>
(
0
,
0
),
[
&
](
auto
x
)
{
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_type
<
value_type
>>
(
0
,
0
),
[
&
](
auto
x1
,
auto
x2
)
{
return
make_array
(
x
,
x
*
x
)
*
vec_type
<
value_type
>
{
1.0
/
relements
};
auto
x
=
op
(
x1
,
x2
);
})(
input
);
return
make_array
(
x
,
x
*
x
)
*
vec_type
<
value_type
>
{
1.0
/
relements
};
})(
input1
,
input2
);
auto
mean_x
=
means
[
0
];
auto
mean_x
=
means
[
0
];
auto
mean_x2
=
means
[
1
];
auto
mean_x2
=
means
[
1
];
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
value_type
eps_val
=
eps
;
// implicit conversion for eps
value_type
eps_val
=
eps
;
// implicit conversion for eps
r
.
inner
([
&
](
auto
&
y
,
auto
x1
,
auto
x2
,
auto
...
xs
)
{
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
x
=
op
(
x1
,
x2
);
auto
m
=
x
-
mean_x
;
auto
m
=
x
-
mean_x
;
// m * rsqrt(mean(m ^ 2) + epsilon)
// m * rsqrt(mean(m ^ 2) + epsilon)
y
=
compute
(
m
*
rsqrt
(
variance
+
eps_val
),
xs
...);
y
=
compute
(
m
*
rsqrt
(
variance
+
eps_val
),
xs
...);
})(
output
,
input
1
,
input2
,
inputs
...);
})(
output
,
input
,
inputs
...);
});
});
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
a92cde80
...
@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
...
@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#else
#else
constexpr
index_int
lanes_per_thread
=
64
;
constexpr
index_int
lanes_per_thread
=
64
;
#endif
#endif
using
type
=
decltype
(
f
(
0
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
type
x
=
init
;
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
x
=
op
(
x
,
f
(
i
));
});
idx
.
local_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
dpp_reduce
(
x
,
op
);
dpp_reduce
(
x
,
op
);
const
auto
ldsidx
=
idx
.
local
/
lanes_per_thread
;
const
auto
ldsidx
=
idx
.
local
/
lanes_per_thread
;
...
@@ -131,7 +131,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
...
@@ -131,7 +131,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
using
type
=
decltype
(
f
(
0
));
using
type
=
decltype
(
f
(
0
));
__shared__
type
buffer
[
idx
.
max_nlocal
()];
__shared__
type
buffer
[
idx
.
max_nlocal
()];
type
x
=
init
;
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
x
=
op
(
x
,
f
(
i
));
});
idx
.
local_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
buffer
[
idx
.
local
]
=
x
;
buffer
[
idx
.
local
]
=
x
;
__syncthreads
();
__syncthreads
();
...
@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i)
...
@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i)
namespace
reduce
{
namespace
reduce
{
struct
inner_storage_tag
{
};
template
<
class
T
>
using
is_inner_storage
=
is_base_of
<
inner_storage_tag
,
remove_cv_t
<
remove_reference_t
<
T
>>>
;
template
<
class
R
,
class
F
>
struct
storage_access
:
F
{
using
type
=
R
;
};
template
<
class
R
,
class
F
>
constexpr
storage_access
<
R
,
F
>
make_storage_access
(
F
f
)
{
return
{{
f
}};
}
template
<
class
Slicer
,
class
F
>
template
<
class
Slicer
,
class
F
>
constexpr
auto
sliced
(
Slicer
slicer
,
F
f
)
constexpr
auto
sliced
(
Slicer
slicer
,
F
f
)
{
{
...
@@ -191,20 +210,140 @@ constexpr auto compute_reduce_axis()
...
@@ -191,20 +210,140 @@ constexpr auto compute_reduce_axis()
template
<
class
Input
,
index_int
Axis
>
template
<
class
Input
,
index_int
Axis
>
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
template
<
class
Derived
>
struct
reducer_base
{
template
<
class
T
>
__device__
auto
make_inner_slice
(
T
x
)
const
{
if
constexpr
(
is_inner_storage
<
T
>
{})
{
return
x
;
}
else
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
([
=
](
auto
i
,
auto
...)
->
auto
&
{
return
t
[
i
];
});
}
}
template
<
class
T
,
class
...
Ts
>
constexpr
auto
get_size
(
T
&&
x
,
[[
maybe_unused
]]
Ts
&&
...
xs
)
const
{
MIGRAPHX_ASSERT
(
get_size
(
x
)
==
get_size
(
xs
...));
return
get_size
(
x
);
}
template
<
class
T
,
class
...
Ts
>
constexpr
auto
get_size
(
T
&&
x
)
const
{
if
constexpr
(
is_inner_storage
<
T
>
{})
{
return
x
.
rsize
();
}
else
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
return
t
.
size
();
}
}
template
<
class
F
>
__device__
auto
inner_sliced
(
F
f
)
const
{
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
get_size
(
xs
...),
make_inner_slice
(
xs
)...);
};
}
template
<
class
T
>
static
__device__
typename
T
::
type
&
decl_inner_storage
(
const
T
&
);
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
this
->
inner_sliced
([
=
](
auto
n
,
auto
&&
...
xs
)
{
using
result_type
=
decltype
(
f
(
decl_inner_storage
(
xs
)...));
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
if
constexpr
(
is_void
<
result_type
>
{})
{
derived
.
inner_void_impl
(
f
,
n
,
xs
...);
}
else
{
return
derived
.
template
inner_impl
<
result_type
>(
f
,
n
,
xs
...);
}
});
}
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
return
this
->
inner_sliced
([
=
](
auto
n
,
auto
&&
...
xs
)
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
return
derived
.
reduce_impl
(
op
,
init
,
read
,
n
,
xs
...);
});
}
template
<
class
Op
,
class
T
>
__device__
auto
reduce
(
Op
op
,
T
init
)
const
{
return
this
->
reduce
(
op
,
init
,
op
::
id
{});
}
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
f
();
}
template
<
class
Input
>
constexpr
auto
elements
()
const
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
using
reduce_type
=
decltype
(
derived
.
slice
(
Input
{}));
using
value_type
=
typename
Input
::
type
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
return
relements
*
vec_size
<
value_type
>
();
else
return
relements
;
}
};
struct
block
struct
block
{
{
template
<
class
Slicer
>
template
<
class
Slicer
>
struct
reducer
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
{
index
idx
;
index
idx
;
Slicer
slice
;
Slicer
slice
;
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
template
<
class
T
,
index_int
N
,
class
Size
>
struct
inner_storage
:
inner_storage_tag
{
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
using
type
=
T
;
return
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
array
<
T
,
N
>
arr
;
return
vec_reduce
(
read
(
x
[
j
],
xs
[
j
]...),
op
);
constexpr
Size
rsize
()
const
{
return
{};
}
});
template
<
class
U
,
class
V
>
constexpr
auto
&
operator
()(
U
,
V
d
)
const
{
return
arr
[
d
];
}
template
<
class
U
,
class
V
>
constexpr
auto
&
operator
()(
U
,
V
d
)
{
return
arr
[
d
];
}
};
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
...
Ts
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
Ts
&&
...
xs
)
const
{
return
block_reduce
(
idx
,
op
,
init
,
n
,
[
&
](
auto
j
,
auto
d
)
{
return
vec_reduce
(
read
(
xs
(
j
,
d
)...),
op
);
});
});
}
}
...
@@ -215,31 +354,26 @@ struct block
...
@@ -215,31 +354,26 @@ struct block
f
();
f
();
}
}
template
<
class
F
>
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner
(
F
f
)
const
__device__
void
inner
_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
idx
.
local_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
f
(
xs
(
j
,
d
)...);
});
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
});
}
}
template
<
class
Input
>
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
constexpr
auto
elements
(
)
const
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
{
using
reduce_type
=
decltype
(
slice
(
Input
{}));
using
max_iterations
=
decltype
(
idx
.
max_local_stride_iterations
(
n
));
using
value_type
=
typename
Input
::
type
;
inner_storage
<
R
,
max_iterations
{},
N
>
storage
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
idx
.
local_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
storage
(
j
,
d
)
=
f
(
xs
(
j
,
d
)...);
});
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
return
storage
;
return
relements
*
vec_size
<
value_type
>
();
else
return
relements
;
}
}
};
};
template
<
class
Slicer
>
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
{
return
reducer
<
Slicer
>
{
idx
,
slicer
};
return
reducer
<
Slicer
>
{
{},
idx
,
slicer
};
}
}
template
<
class
Output
,
class
F
>
template
<
class
Output
,
class
F
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
View file @
a92cde80
...
@@ -30,18 +30,19 @@
...
@@ -30,18 +30,19 @@
namespace
migraphx
{
namespace
migraphx
{
template
<
index_int
Axis
,
class
Input
,
class
Output
>
template
<
index_int
Axis
,
class
Input
,
class
Output
>
__device__
void
softmax
(
Input
input
,
Output
output
)
__device__
void
softmax
(
Input
input
1
,
Output
output
)
{
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
(
op
::
id
{})(
input1
);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
const
auto
c
=
vec_at
(
r
.
slice
(
input
)[
0
],
0
);
const
auto
c
=
vec_at
(
r
.
slice
(
input
1
)[
0
],
0
);
#else
#else
const
auto
c
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
const
auto
c
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
#endif
#endif
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
auto
exp_in
=
r
.
inner
([
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
c
);
})(
input
);
return
migraphx
::
convert
<
float
>
(
migraphx
::
exp
(
x
-
c
));
auto
batch_sum
=
})(
input
);
r
.
reduce
(
op
::
sum
{},
0
,
[](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
x
);
})(
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
migraphx
::
exp
(
x
-
c
)
/
batch_sum
;
})(
output
,
input
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
/
batch_sum
;
})(
output
,
exp_in
);
});
});
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
a92cde80
...
@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
...
@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN
(
is_nothrow_constructible
);
MIGRAPHX_BUILTIN_TYPE_TRAITN
(
is_nothrow_constructible
);
MIGRAPHX_BUILTIN_TYPE_TRAITN
(
is_trivially_constructible
);
MIGRAPHX_BUILTIN_TYPE_TRAITN
(
is_trivially_constructible
);
template
<
class
T
>
struct
remove_cv
{
using
type
=
T
;
};
template
<
class
T
>
struct
remove_cv
<
const
T
>
:
remove_cv
<
T
>
{
};
template
<
class
T
>
struct
remove_cv
<
volatile
T
>
:
remove_cv
<
T
>
{
};
template
<
class
T
>
using
remove_cv_t
=
typename
remove_cv
<
T
>::
type
;
template
<
class
T
>
template
<
class
T
>
struct
remove_reference
struct
remove_reference
{
{
...
@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
...
@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template
<
class
T
>
template
<
class
T
>
using
add_pointer_t
=
typename
add_pointer
<
T
>::
type
;
using
add_pointer_t
=
typename
add_pointer
<
T
>::
type
;
template
<
class
T
>
struct
is_void
:
is_same
<
void
,
remove_cv_t
<
T
>>
{
};
template
<
class
...
Ts
>
template
<
class
...
Ts
>
struct
common_type
;
struct
common_type
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment