Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6ac41ed8
Commit
6ac41ed8
authored
Feb 07, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/jit-reduce-reg' into ck-gsg
parents
64c4c13b
543e166a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
113 additions
and
3 deletions
+113
-3
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+2
-0
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+2
-1
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+94
-1
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+2
-1
test/verify/test_reduce_op_large.cpp
test/verify/test_reduce_op_large.cpp
+13
-0
No files found.
src/targets/gpu/jit/reduce.cpp
View file @
6ac41ed8
...
...
@@ -127,6 +127,8 @@ struct reduce_compiler : compiler<reduce_compiler>
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
block_size
=
compute_block_size
(
relements
,
256
);
if
(
relements
>
block_size
*
256
)
algo
=
"block_large"
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
6ac41ed8
...
...
@@ -46,8 +46,9 @@ template <index_int Axis,
__device__
void
generic_binary_layernorm
(
F
compute
,
BinOp
op
,
float
eps
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
using
block
=
reduce
::
auto_block
<
reduce
::
reduce_elements_with_axis
<
Input1
,
Axis
>
()
>
;
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
block
::
template
run
<
reduce_output
>([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
using
value_type
=
typename
Input1
::
type
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
6ac41ed8
...
...
@@ -128,7 +128,7 @@ template <class Op, class T, class Index, class F>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
using
type
=
decltype
(
f
(
0
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()];
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
...
...
@@ -388,6 +388,79 @@ struct block
}
};
struct
block_large
{
template
<
class
Slicer
>
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
index
idx
;
Slicer
slice
;
template
<
class
Size
,
class
F
>
struct
inner_storage
:
inner_storage_tag
{
using
type
=
remove_reference_t
<
decltype
(
declval
<
F
>
()(
0
,
_c
<
0
>
))
>
;
F
f
;
constexpr
Size
rsize
()
const
{
return
{};
}
template
<
class
U
,
class
V
>
constexpr
auto
operator
()(
U
j
,
V
d
)
const
{
return
f
(
j
,
d
);
}
};
template
<
class
Size
,
class
F
>
constexpr
inner_storage
<
Size
,
F
>
make_inner_storage
(
Size
,
F
f
)
{
return
{
f
};
}
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
...
Ts
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
Ts
&&
...
xs
)
const
{
return
block_reduce
(
idx
,
op
,
init
,
index_int
{
n
},
[
&
](
auto
j
,
auto
d
)
{
return
vec_reduce
(
read
(
xs
(
j
,
d
)...),
op
);
});
}
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
if
(
idx
.
local
==
0
)
f
();
}
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
void
inner_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
idx
.
local_stride
(
index_int
{
n
},
[
&
](
auto
j
,
auto
d
)
{
f
(
xs
(
j
,
d
)...);
});
}
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
return
make_inner_storage
(
n
,
[
=
](
auto
j
,
auto
d
)
{
return
f
(
xs
(
j
,
d
)...);
});
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{{},
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
static
__device__
void
run
(
F
f
)
{
auto
idx
=
make_index
();
constexpr
auto
nelements
=
get_shape_c
<
Output
>
{}.
elements
();
idx
.
global_stride
(
nelements
*
idx
.
nlocal
(),
[
&
](
auto
i
)
{
const
auto
out_idx
=
get_shape_c
<
Output
>
{}.
multi
(
i
/
idx
.
nlocal
());
f
(
out_idx
,
make
(
idx
,
[
&
](
auto
input
)
{
return
reduce_slice
<
Output
>
(
input
,
out_idx
);
}));
});
}
};
struct
lane
{
template
<
class
Slicer
>
...
...
@@ -466,6 +539,26 @@ struct lane
}
};
// TODO: Remove these in the future when they can be selected in the compiler class
template
<
index_int
RElements
>
constexpr
auto
pick_block
()
{
using
nlocal
=
decltype
(
index
{}.
max_nlocal
());
if
constexpr
(
RElements
<
nlocal
{}
*
256
)
return
block
{};
else
return
block_large
{};
}
template
<
index_int
RElements
>
using
auto_block
=
decltype
(
pick_block
<
RElements
>
());
template
<
class
Input
,
index_int
Axis
>
constexpr
auto
reduce_elements_with_axis
()
{
constexpr
auto
s
=
get_shape_c
<
Input
>
{};
return
s
.
lens
[
Axis
];
}
}
// namespace reduce
template
<
class
Algo
,
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
View file @
6ac41ed8
...
...
@@ -32,7 +32,8 @@ namespace migraphx {
template
<
index_int
Axis
,
class
Input
,
class
Output
>
__device__
void
softmax
(
Input
input1
,
Output
output
)
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
using
block
=
reduce
::
auto_block
<
reduce
::
reduce_elements_with_axis
<
Input
,
Axis
>
()
>
;
block
::
template
run
<
reduce
::
with_axis
<
Input
,
Axis
>
>
([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
(
op
::
id
{})(
input1
);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
const
auto
c
=
vec_at
(
r
.
slice
(
input1
)[
0
],
0
);
...
...
test/verify/test_reduce_op_large.cpp
View file @
6ac41ed8
...
...
@@ -76,3 +76,16 @@ struct test_reduce_mean_2 : verify_program<test_reduce_mean_2>
return
p
;
};
};
struct
test_large_reduce_mean
:
verify_program
<
test_large_reduce_mean
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
256
*
256
*
16
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
mm
->
add_instruction
(
migraphx
::
op
::
reduce_mean
{{
1
}},
x
);
return
p
;
};
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment