Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
276184b2
Commit
276184b2
authored
Jun 30, 2022
by
Paul
Browse files
Load into registers first
parent
c56d6e9e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
24 deletions
+75
-24
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+11
-0
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+12
-9
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+52
-15
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
276184b2
...
@@ -53,6 +53,17 @@ struct index
...
@@ -53,6 +53,17 @@ struct index
return
blockDim
.
x
;
// NOLINT
return
blockDim
.
x
;
// NOLINT
}
}
#endif
#endif
template
<
class
N
>
constexpr
auto
max_global_stride_iterations
(
N
n
)
const
{
return
_c
<
1
>
+
n
/
nglobal
();
}
template
<
class
N
>
constexpr
auto
max_local_stride_iterations
(
N
n
)
const
{
return
_c
<
1
>
+
n
/
nlocal
();
}
template
<
class
F
>
template
<
class
F
>
__device__
void
global_stride
(
index_int
n
,
F
f
)
const
__device__
void
global_stride
(
index_int
n
,
F
f
)
const
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
276184b2
...
@@ -22,24 +22,27 @@ __device__ void generic_binary_layernorm(
...
@@ -22,24 +22,27 @@ __device__ void generic_binary_layernorm(
MIGRAPHX_ASSERT
(
relements
>
0
);
MIGRAPHX_ASSERT
(
relements
>
0
);
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
using
value_type
=
typename
Input1
::
type
;
using
value_type
=
typename
Input1
::
type
;
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
auto
mean
=
[
&
](
auto
f
)
{
auto
mean
=
[
&
](
auto
f
)
{
return
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
1
,
auto
x2
)
{
return
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
f
(
x
1
,
x2
)
/
value_type
{
relements
};
return
f
(
x
)
/
value_type
{
relements
};
})(
input
1
,
input2
);
})(
input
);
};
};
// mean(x)
// mean(x)
auto
mean_x
=
mean
(
op
);
auto
mean_x
=
mean
(
op
::
id
{}
);
// mean(m ^ 2)
// mean(m ^ 2)
auto
mean_m2
=
mean
([
&
](
auto
x
1
,
auto
x2
)
{
auto
mean_m2
=
mean
([
&
](
auto
x
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
auto
m
=
x
-
mean_x
;
return
m
*
m
;
return
m
*
m
;
});
});
r
.
inner
([
&
](
auto
&
y
,
auto
x
1
,
auto
x2
,
auto
...
xs
)
{
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
m
=
op
(
x1
,
x2
)
-
mean_x
;
auto
m
=
x
-
mean_x
;
// m * rsqrt(mean(m ^ 2) + 1e-12)
// m * rsqrt(mean(m ^ 2) + 1e-12)
y
=
compute
(
m
*
rsqrt
(
mean_m2
+
value_type
{
1e-12
}),
xs
...);
y
=
compute
(
m
*
rsqrt
(
mean_m2
+
value_type
{
1e-12
}),
xs
...);
})(
output
,
input
1
,
input2
,
inputs
...);
})(
output
,
input
,
inputs
...);
});
});
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
276184b2
...
@@ -147,25 +147,35 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
...
@@ -147,25 +147,35 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
}
}
#endif
#endif
namespace
reduce
{
struct
inner_array_base
{};
template
<
class
Output
,
class
Input
,
class
T
>
template
<
class
Output
,
class
Input
,
class
T
>
constexpr
auto
reduce_slice
(
Input
input
,
T
i
)
constexpr
auto
reduce_slice
(
Input
input
,
T
i
)
{
{
constexpr
auto
lens
=
transform
(
get_shape_c
<
Input
>
{}.
lens
,
if
constexpr
(
is_base_of
<
inner_array_base
,
Input
>
{})
get_shape_c
<
Output
>
{}.
lens
,
{
[](
index_int
x
,
index_int
y
)
->
index_int
{
return
input
;
if
(
x
==
y
)
}
return
1
;
else
return
x
;
{
});
constexpr
auto
lens
=
transform
(
get_shape_c
<
Input
>
{}.
lens
,
;
get_shape_c
<
Output
>
{}.
lens
,
constexpr
auto
s
=
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
[](
index_int
x
,
index_int
y
)
->
index_int
{
MIGRAPHX_ASSERT
((
input
.
get_shape
().
index
(
i
)
+
s
.
element_space
())
<=
if
(
x
==
y
)
input
.
get_shape
().
element_space
());
return
1
;
return
make_tensor_view
(
&
input
[
i
],
s
);
return
x
;
});
;
constexpr
auto
s
=
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
MIGRAPHX_ASSERT
((
input
.
get_shape
().
index
(
i
)
+
s
.
element_space
())
<=
input
.
get_shape
().
element_space
());
return
make_tensor_view
(
&
input
[
i
],
s
);
}
}
}
namespace
reduce
{
template
<
class
Slicer
,
class
F
>
template
<
class
Slicer
,
class
F
>
constexpr
auto
sliced
(
Slicer
slicer
,
F
f
)
constexpr
auto
sliced
(
Slicer
slicer
,
F
f
)
{
{
...
@@ -217,11 +227,38 @@ struct block
...
@@ -217,11 +227,38 @@ struct block
f
();
f
();
}
}
template
<
class
T
,
index_int
N
,
index_int
Stride
,
class
Shape
>
struct
inner_array
:
inner_array_base
{
array
<
T
,
N
>
arr
;
constexpr
Shape
get_shape
()
const
{
return
Shape
{};}
template
<
class
U
>
constexpr
auto
&
operator
[](
U
i
)
const
{
return
arr
[
i
/
Stride
];
}
template
<
class
U
>
constexpr
auto
&
operator
[](
U
i
)
{
return
arr
[
i
/
Stride
];
}
};
template
<
class
F
>
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
__device__
auto
inner
(
F
f
)
const
{
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
using
result_type
=
decltype
(
f
(
x
[
0
],
xs
[
0
]...));
if
constexpr
(
is_same
<
result_type
,
void
>
{})
{
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
}
else
{
inner_array
<
result_type
,
decltype
(
idx
.
max_local_stride_iterations
(
x
.
get_shape
().
elements
())){},
decltype
(
idx
.
nlocal
()){},
decltype
(
x
.
get_shape
())
>
y
;
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
y
[
j
]
=
f
(
x
[
j
],
xs
[
j
]...);
});
return
y
;
}
});
});
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment