Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
44e87b4e
Commit
44e87b4e
authored
May 23, 2022
by
rocking
Browse files
Implement reduction meand and reduction square mean
parent
ac543313
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
5 deletions
+10
-5
example/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp
...ple/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp
+10
-5
No files found.
example/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp
View file @
44e87b4e
...
@@ -45,7 +45,7 @@ using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>;
...
@@ -45,7 +45,7 @@ using D1ReduceOp = ck::reduce::Add<ReduceAccDataType>;
using
DxsReduceOp
=
ck
::
Tuple
<
D0ReduceOp
,
D1ReduceOp
>
;
using
DxsReduceOp
=
ck
::
Tuple
<
D0ReduceOp
,
D1ReduceOp
>
;
using
UnaryIdenticElementOp
=
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
fals
e
>
;
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
tru
e
>
;
using
UnarySquareElementOp
=
using
UnarySquareElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
DxsInElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
DxsInElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
...
@@ -181,6 +181,9 @@ int main(int argc, char* argv[])
...
@@ -181,6 +181,9 @@ int main(int argc, char* argv[])
auto
dxs_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
auto
dxs_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()));
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()));
auto
dxs_in_element_op
=
DxsInElementOp
{};
auto
dxs_out_element_op
=
DxsOutElementOp
{
M
,
M
};
// do GEMM
// do GEMM
auto
gemm
=
DeviceGemmReduceInstance
{};
auto
gemm
=
DeviceGemmReduceInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
...
@@ -197,8 +200,8 @@ int main(int argc, char* argv[])
...
@@ -197,8 +200,8 @@ int main(int argc, char* argv[])
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
D
xs
InE
lement
Op
{}
,
d
xs
_in_e
lement
_op
,
D
xs
OutE
lement
Op
{}
);
d
xs
_out_e
lement
_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -256,12 +259,14 @@ int main(int argc, char* argv[])
...
@@ -256,12 +259,14 @@ int main(int argc, char* argv[])
float
d0_val
=
0
;
float
d0_val
=
0
;
float
d1_val
=
0
;
float
d1_val
=
0
;
UnaryIdenticElementOp
{}(
d0_val
,
c_val
);
dxs_in_element_op
(
ck
::
Number
<
0
>
{}
)
(
d0_val
,
c_val
);
UnarySquareElementOp
{}(
d1_val
,
c_val
);
dxs_in_element_op
(
ck
::
Number
<
1
>
{}
)
(
d1_val
,
c_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
}
}
dxs_out_element_op
(
ck
::
Number
<
0
>
{})(
d0_acc
,
d0_acc
);
dxs_out_element_op
(
ck
::
Number
<
1
>
{})(
d1_acc
,
d1_acc
);
d0_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d0_acc
);
d0_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d0_acc
);
d1_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d1_acc
);
d1_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d1_acc
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment