Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d807d05e
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "864ed34f565ea5c066778c3c1aa708903ec22be4"
Commit
d807d05e
authored
Sep 19, 2023
by
Jing Zhang
Browse files
add scaleAdd_vec4 example
parent
ca0f9579
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
9 deletions
+30
-9
example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp
example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp
+29
-8
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
+1
-1
No files found.
example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp
View file @
d807d05e
...
@@ -41,21 +41,42 @@ using BLayout = Col;
...
@@ -41,21 +41,42 @@ using BLayout = Col;
using
DLayout
=
Row
;
using
DLayout
=
Row
;
using
ELayout
=
Row
;
using
ELayout
=
Row
;
struct
Add
struct
Add
Scale
{
{
static
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
static
constexpr
auto
I1
=
ck
::
Number
<
1
>
{};
static
constexpr
auto
I2
=
ck
::
Number
<
2
>
{};
static
constexpr
auto
I3
=
ck
::
Number
<
3
>
{};
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
ck
::
half
2
_t
&
a
,
const
ck
::
half
2
_t
&
a0
,
const
ck
::
half
2
_t
&
a1
)
const
operator
()(
ck
::
half
4
_t
&
a
,
const
ck
::
half
4
_t
&
a0
,
const
ck
::
half
4
_t
&
a1
)
const
{
{
a
=
a0
+
a1
;
const
auto
a0_v_t
=
ck
::
vector_type
<
ck
::
half_t
,
4
>
{
a0
};
const
auto
a1_v_t
=
ck
::
vector_type
<
ck
::
half_t
,
4
>
{
a1
};
auto
r_v_t
=
ck
::
vector_type
<
ck
::
half_t
,
4
>
{};
r_v_t
.
AsType
<
ck
::
half_t
>
()(
I0
)
=
scale
*
(
a0_v_t
.
AsType
<
ck
::
half_t
>
()[
I0
]
+
a1_v_t
.
AsType
<
ck
::
half_t
>
()[
I0
]);
r_v_t
.
AsType
<
ck
::
half_t
>
()(
I1
)
=
scale
*
(
a0_v_t
.
AsType
<
ck
::
half_t
>
()[
I1
]
+
a1_v_t
.
AsType
<
ck
::
half_t
>
()[
I1
]);
r_v_t
.
AsType
<
ck
::
half_t
>
()(
I2
)
=
scale
*
(
a0_v_t
.
AsType
<
ck
::
half_t
>
()[
I2
]
+
a1_v_t
.
AsType
<
ck
::
half_t
>
()[
I2
]);
r_v_t
.
AsType
<
ck
::
half_t
>
()(
I3
)
=
scale
*
(
a0_v_t
.
AsType
<
ck
::
half_t
>
()[
I3
]
+
a1_v_t
.
AsType
<
ck
::
half_t
>
()[
I3
]);
a
=
r_v_t
.
AsType
<
ck
::
half4_t
>
()[
I0
];
}
}
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()(
ck
::
half_t
&
a
,
const
ck
::
half_t
&
a0
,
const
ck
::
half_t
&
a1
)
const
operator
()(
ck
::
half_t
&
a
,
const
ck
::
half_t
&
a0
,
const
ck
::
half_t
&
a1
)
const
{
{
a
=
a0
+
a1
;
a
=
scale
*
(
a0
+
a1
)
;
}
}
static
constexpr
ck
::
index_t
vec_len
=
4
;
static
constexpr
ck
::
index_t
vec_len
=
4
;
float
scale
=
1.0
;
};
};
struct
AlphaBetaAdd
struct
AlphaBetaAdd
...
@@ -76,7 +97,7 @@ struct AlphaBetaAdd
...
@@ -76,7 +97,7 @@ struct AlphaBetaAdd
float
beta_
;
float
beta_
;
};
};
using
AElementOp
=
Add
;
using
AElementOp
=
Add
Scale
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
AlphaBetaAdd
;
using
CDEElementOp
=
AlphaBetaAdd
;
...
@@ -248,7 +269,7 @@ int main(int argc, char* argv[])
...
@@ -248,7 +269,7 @@ int main(int argc, char* argv[])
d_device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{
0.2
};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
auto
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
...
@@ -312,14 +333,14 @@ int main(int argc, char* argv[])
...
@@ -312,14 +333,14 @@ int main(int argc, char* argv[])
BDataType
,
BDataType
,
CShuffleDataType
,
CShuffleDataType
,
AccDataType
,
AccDataType
,
BElementOp
,
PassThrough
,
BElementOp
,
BElementOp
,
PassThrough
>
;
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
b_element_op
,
b_element_op
,
PassThrough
{});
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
PassThrough
{}
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
View file @
d807d05e
...
@@ -133,7 +133,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
...
@@ -133,7 +133,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}
}
template
<
typename
T
>
template
<
typename
T
>
using
has_vec_len
=
decltype
(
std
::
declval
<
T
&>
().
vec_len
()
);
using
has_vec_len
=
decltype
(
std
::
declval
<
T
&>
().
vec_len
);
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment