gaoqiong / composable_kernel · Commits · 81b26528

Commit 81b26528, authored Nov 22, 2021 by Chao Liu

    added bias add; worked around compiler issues

Parent: 4f2c8bce

Showing 5 changed files with 181 additions and 42 deletions (+181 −42)
    composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp        +5   −5
    composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp    +59  −13
    composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp    +1   −1
    composable_kernel/include/utility/config.hpp                                            +5   −0
    example/2_gemm_xdl_bias_add/gemm_xdl_bias_add.cpp                                       +111 −23
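For orientation before the per-file diffs: the commit wires a GEMM-with-bias-add epilogue through the threadwise copy, C[m, n] = alpha * (A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m], where C0 shares C's layout and C1 is a per-row bias broadcast along n. Below is a minimal host-side reference sketch of that formula in plain C++ (row-major layout and all names are illustrative, not CK API):

#include <cstddef>
#include <vector>

// Reference bias-add GEMM: c = alpha*(a*b) + beta*c0 + gamma*c1, with c1 broadcast along n.
// Row-major, contiguous buffers; purely illustrative, not the CK device path.
void gemm_bias_add_ref(const std::vector<float>& a,  // M x K
                       const std::vector<float>& b,  // K x N
                       const std::vector<float>& c0, // M x N (residual)
                       const std::vector<float>& c1, // M     (bias vector)
                       std::vector<float>& c,        // M x N (output)
                       std::size_t M, std::size_t N, std::size_t K,
                       float alpha, float beta, float gamma)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            double acc = 0;
            for(std::size_t k = 0; k < K; ++k)
                acc += double(a[m * K + k]) * double(b[k * N + n]);

            c[m * N + n] = float(alpha * acc + beta * c0[m * N + n] + gamma * c1[m]);
        }
}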
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp

@@ -50,7 +50,7 @@ template <typename SrcData,
           typename DstData,
           typename SrcDesc,
           typename DstDesc,
-          typename SrcElementwiseOperation,
+          typename DstElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,

@@ -72,9 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
                                                             const Index& dst_slice_origin_idx,
-                                                            const SrcElementwiseOperation src_element_op)
+                                                            const DstElementwiseOperation& dst_element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
-          src_element_op_{src_element_op}
+          dst_element_op_{dst_element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");

@@ -201,7 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 // apply element-wise operation and type convert
                 dst_vector.template AsType<DstData>()(i) =
-                    type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
+                    type_convert<DstData>(dst_element_op_(src_buf[Number<src_offset>{}]));
             });

             const bool is_dst_valid =

@@ -378,7 +378,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     private:
     DstCoord dst_coord_;

-    SrcElementwiseOperation src_element_op_;
+    const DstElementwiseOperation dst_element_op_;
 }; // namespace ck

 // Assume:
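With the rename above, ThreadwiseTensorSliceTransfer_v1r3 now carries a destination-side elementwise operation: each source element is passed through dst_element_op_ and then type-converted to DstData before the store. A hedged sketch of such a unary functor, written in the same style as the PassThrough/Relu functors in the example file (illustrative only, not part of this commit):

// Unary destination elementwise op: v1r3 calls this on every element it copies,
// then converts the result to DstData. A simple scale is enough to show the contract.
struct Scale
{
    float alpha = 2.0f;

    template <typename T>
    __host__ __device__ constexpr T operator()(T v) const
    {
        return static_cast<T>(alpha) * v; // applied element-by-element during the copy
    }
};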
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp

@@ -32,7 +32,7 @@ template <typename SrcData,
           typename DstDesc,
           typename Dst0Desc, // this is really one of sources, but it has same shape as DstDesc
           typename Dst1Desc, // this is really one of sources, but it has same shape as DstDesc
-          typename SrcElementwiseOperation,
+          typename DstElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,

@@ -60,11 +60,11 @@ struct ThreadwiseTensorSliceTransfer_v1r4
         const Dst0Desc& dst0_desc,
         const Dst1Desc& dst1_desc,
         const Index& dst_slice_origin_idx,
-        const SrcElementwiseOperation src_element_op)
+        const DstElementwiseOperation& dst_element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
           dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)),
           dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin_idx)),
-          src_element_op_{src_element_op}
+          dst_element_op_{dst_element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");

@@ -258,15 +258,45 @@ struct ThreadwiseTensorSliceTransfer_v1r4
             using dst_vector_t =
                 typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

-            // copy data from src_buf into dst_vector
-            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-                constexpr index_t src_offset = src_desc.CalculateOffset(
-                    src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
-
-                // apply element-wise operation and type convert
-                dst_vector.template AsType<DstData>()(i) =
-                    type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
-            });
+            // load dst0 and dst1 and apply elementwise operation
+            {
+                // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+                // TODO: fix this
+                static_assert(DstScalarPerVector == 1, "wrong!");
+
+                // copy data from src_buf into dst_vector_src_data
+                constexpr index_t src_offset =
+                    src_desc.CalculateOffset(src_slice_origin_idx + dst_data_idx);
+
+                const SrcData src_v = src_buf[Number<src_offset>{}];
+
+                // load dst0 and dst1
+                const bool is_dst0_valid =
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst0_desc,
+                                                                                dst0_coord_);
+
+                const bool is_dst1_valid =
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst1_desc,
+                                                                                dst1_coord_);
+
+                const DstData dst0_v =
+                    dst0_buf.template Get<DstData>(dst0_coord_.GetOffset(), is_dst0_valid);
+
+                const DstData dst1_v =
+                    dst1_buf.template Get<DstData>(dst1_coord_.GetOffset(), is_dst1_valid);
+
+#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
+                // apply element-wise operation in SrcData type
+                const SrcData dst_v = dst_element_op_(
+                    src_v, type_convert<SrcData>(dst0_v), type_convert<SrcData>(dst1_v));
+
+                // apply type convert
+                dst_vector.template AsType<DstData>()(Number<0>{}) = type_convert<DstData>(dst_v);
+#else
+                // apply element-wise operation in DstData type
+                const DstData dst_v = dst_element_op_(src_v, dst0_v, dst1_v);
+
+                dst_vector.template AsType<DstData>()(Number<0>{}) = dst_v;
+#endif
+            }

             const bool is_dst_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

@@ -327,11 +357,27 @@ struct ThreadwiseTensorSliceTransfer_v1r4
                 {
                     move_tensor_coordinate(
                         dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
+
+                    // dst0
+                    move_tensor_coordinate(
+                        dst0_desc, dst0_coord_, dst0_forward_steps[dim_access_order[i]]);
+
+                    // dst1
+                    move_tensor_coordinate(
+                        dst1_desc, dst1_coord_, dst1_forward_steps[dim_access_order[i]]);
                 }
                 else
                 {
                     move_tensor_coordinate(
                         dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
+
+                    // dst0
+                    move_tensor_coordinate(
+                        dst0_desc, dst0_coord_, dst0_backward_steps[dim_access_order[i]]);
+
+                    // dst1
+                    move_tensor_coordinate(
+                        dst1_desc, dst1_coord_, dst1_backward_steps[dim_access_order[i]]);
                 }
             }
         });

@@ -469,7 +515,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
     DstCoord dst_coord_;
     Dst0Coord dst0_coord_;
     Dst1Coord dst1_coord_;

-    SrcElementwiseOperation src_element_op_;
+    const DstElementwiseOperation dst_element_op_;
 }; // namespace ck

 } // namespace ck
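The #if/#else above is the compiler workaround from the commit message: doing the epilogue math in fp32 requires fp16/fp32 converts in the hot loop, which (per the new comment in config.hpp) causes register spills, so the default path hands the fp16 values to the functor unconverted. Stripped of the descriptor machinery, the two paths reduce to the sketch below (illustrative names; SrcData/DstData stand for the fp32 accumulator and the fp16 output type, and the macro comes from config.hpp):

// Minimal sketch of the two epilogue paths selected by the workaround macro.
// Assumes config.hpp (which defines the macro) has been included; _Float16 is
// clang's native half type, used here in place of ck::half_t for illustration.
using SrcData = float;     // GEMM accumulator type
using DstData = _Float16;  // output / residual / bias element type

template <typename Op>
DstData apply_epilogue(Op op, SrcData src_v, DstData dst0_v, DstData dst1_v)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
    // non-workaround path: convert the fp16 inputs up, do the math in fp32,
    // convert the result down once at the end
    const SrcData r = op(src_v, static_cast<SrcData>(dst0_v), static_cast<SrcData>(dst1_v));
    return static_cast<DstData>(r);
#else
    // workaround path (default): pass the fp16 values straight to the functor,
    // which picks a mixed-precision expression the compiler handles without spilling
    return static_cast<DstData>(op(src_v, dst0_v, dst1_v));
#endif
}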
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp

@@ -810,7 +810,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
     SrcCoord src_coord_;
     DstCoord dst_coord_;

-    SrcElementwiseOperation src_element_op_;
+    const SrcElementwiseOperation src_element_op_;
 };

 } // namespace ck
composable_kernel/include/utility/config.hpp

@@ -136,6 +136,11 @@
 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif

+// workaround for register spill due to compiler issue, when casting type between fp32 and fp16
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE 1
+#endif
+
 namespace ck {

 enum InMemoryDataOperationEnum_t
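Because of the #ifndef guard, the workaround (on by default) can be disabled per build with a compiler flag such as -DCK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE=0, or per translation unit as in the hedged sketch below (the short include path mirrors the repo's convention and is illustrative):

// Opt out of the v1r4 type-convert workaround for this translation unit only:
// defining the macro to 0 before config.hpp is seen makes its #ifndef guard a no-op,
// so threadwise_tensor_slice_transfer_v1r4.hpp takes the fp32 element-wise path.
#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE 0
#include "config.hpp"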
example/2_gemm_xdl_bias_add/gemm_xdl_bias_add.cpp

@@ -14,10 +14,6 @@
 #include "device_base.hpp"
 #include "example/2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp"

-// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
-// assume C0 has same layout as C
-// assume C1 is contiguous in memory
-
 struct PassThrough
 {
     template <typename T>
@@ -27,17 +23,60 @@ struct PassThrough
     }
 };

-struct Relu
+// GEMM Bias Add:
+// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
+// assume C0 has same layout as C
+// assume C1 is contiguous in memory
+// C1 presents in memory as 1d vector, but is represented as 2D matrix C1[m, n], with stride = 0 in
+// the "n" dimension
+//
+// alpha * v0 + beta * v1 + gamma * v2
+// v0 is from C matrix
+// v1 is from residual matrix
+// v2 is from bias vector
+struct BiasAdd
 {
-    float alpha = 0.1;
-
-    // ReLU
-    template <typename T>
-    __host__ __device__ constexpr T operator()(T v) const
+#if 1
+    // correct result
+    // no scratch memory, good VGPR allocation (59)
+    // good perf (101Tflops)
+    template <typename T1, typename T2>
+    __host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
     {
-        T tmp = alpha * v;
-
-        return tmp > 0 ? tmp : 0;
+        // compiler seems very volatile to the order of these calculation:
+        // compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register
+        // over-allocation. Therefore, move v0 calculation to the very end
+        float a = T1(0.2) * v1 + T2(0.3) * v2;
+        float b = a + float(0.1) * v0;
+
+        return b;
     }
+#elif 0
+    // correct result
+    // some scratch memory (68), large VGPR usage (126)
+    // very little perf drop (101Tflops)
+    __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
+    {
+        return float(0.1) * v0 + ck::half_t(0.2) * v1 + ck::half_t(0.3) * v2;
+    }
+#elif 0
+    // correct result
+    // some scratch memory (68 dword)
+    // some perf drop (94Tflops)
+    // fp64 instructions are used
+    __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
+    {
+        return 0.1 * v0 + 0.2 * v1 + 0.3 * v2;
+    }
+#elif 1
+    // wrong result
+    // lots of scratch memory
+    // huge perf drop
+    __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
+    {
+        return float(0.1) * v0 + float(0.2) * v1 + float(0.3) * v2;
+    }
+#endif
 };

 template <typename ADataType,
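The comment above says C1 lives in memory as a 1-D vector of length M but is presented to the kernel as a 2-D matrix C1[m, n] whose stride along n is 0, so every column of a row reads the same bias value. A minimal hedged sketch of that broadcast view in plain C++ (illustrative names, not CK's descriptor API):

// View a length-M bias vector as an M x N matrix by giving the n dimension stride 0:
// element (m, n) maps to offset m * stride_m + n * stride_n with stride_n == 0,
// so all N columns of row m alias the single stored value bias[m].
struct BiasAsMatrixView
{
    const float* bias;  // backing 1-D storage, length M
    long stride_m = 1;  // step between rows
    long stride_n = 0;  // broadcast: stepping along n goes nowhere

    float operator()(long m, long n) const { return bias[m * stride_m + n * stride_n]; }
};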
@@ -125,13 +164,49 @@ struct DeviceGemmInstance<float,
 // clang-format on
 };

+template <typename AType,
+          typename BType,
+          typename CType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+static void host_verify(const Tensor<AType>& a_m_k,
+                        const Tensor<BType>& b_k_n,
+                        Tensor<CType>& c_m_n,
+                        const Tensor<CType>& c0_m_n,
+                        const Tensor<CType>& c1_m_n,
+                        const AElementwiseOperation& a_element_op,
+                        const BElementwiseOperation& b_element_op,
+                        const CElementwiseOperation& c_element_op)
+{
+    auto f_mk_kn_mn = [&](auto m, auto n) {
+        const int K = a_m_k.mDesc.GetLengths()[1];
+
+        double v = 0;
+
+        for(int k = 0; k < K; ++k)
+        {
+            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
+                 static_cast<const double>(b_element_op(b_k_n(k, n)));
+        }
+
+        c_m_n(m, n) = c_element_op(v,
+                                   static_cast<const double>(c0_m_n(m, n)),
+                                   static_cast<const double>(c1_m_n(m, n)));
+    };
+
+    make_ParallelTensorFunctor(f_mk_kn_mn,
+                               c_m_n.mDesc.GetLengths()[0],
+                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+}
+
 int main(int argc, char* argv[])
 {
-    if(argc != 4)
+    if(argc != 10)
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: run kernel # of times (>1)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
         exit(0);
     }
@@ -140,18 +215,24 @@ int main(int argc, char* argv[])
     const int nrepeat = std::stoi(argv[3]);

     // GEMM shape
-    ck::index_t M = 3840;
-    ck::index_t N = 4096;
-    ck::index_t K = 4096;
+    ck::index_t M = std::stoi(argv[4]);
+    ck::index_t N = std::stoi(argv[5]);
+    ck::index_t K = std::stoi(argv[6]);

-    ck::index_t StrideA = 4096;
-    ck::index_t StrideB = 4096;
-    ck::index_t StrideC = 4096;
+    ck::index_t StrideA = std::stoi(argv[7]);
+    ck::index_t StrideB = std::stoi(argv[8]);
+    ck::index_t StrideC = std::stoi(argv[9]);

     // matrix data type
+#if 1
     using ADataType = ck::half_t;
     using BDataType = ck::half_t;
     using CDataType = ck::half_t;
+#else
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+#endif

     // matrix layout
     using ALayout = ck::tensor_layout::gemm::RowMajor;
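With the shape and strides now taken from the command line, an invocation that reproduces the previously hard-coded configuration (the binary name is illustrative; the example's build target may differ) would be:

    ./gemm_xdl_bias_add 1 1 5 3840 4096 4096 4096 4096 4096

That is: verification on, integer initialization, 5 timed runs, M=3840, N=4096, K=4096, StrideA=StrideB=StrideC=4096, consistent with the usage message's requirement that M, N, K be multiples of 256, 128, and 32.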
@@ -219,6 +300,8 @@ int main(int argc, char* argv[])
     c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
     c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());

+    auto c_element_op = BiasAdd{};
+
     // do GEMM
     auto gemm = typename DeviceGemmInstance<ADataType,
                                             BDataType,

@@ -228,7 +311,7 @@ int main(int argc, char* argv[])
                                             CLayout,
                                             PassThrough,
                                             PassThrough,
-                                            Relu>::type{};
+                                            decltype(c_element_op)>::type{};

     auto invoker = gemm.MakeInvoker();

     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),

@@ -244,7 +327,7 @@ int main(int argc, char* argv[])
                                       StrideC,
                                       PassThrough{},
                                       PassThrough{},
-                                      Relu{});
+                                      c_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {

@@ -270,8 +353,13 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        host_gemm_mk_kn_mn(
-            a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, Relu{});
-
-        check_error(c_m_n_host_result, c_m_n_device_result);
+        host_verify(a_m_k,
+                    b_k_n,
+                    c_m_n_host_result,
+                    c0_m_n,
+                    c1_m_n,
+                    PassThrough{},
+                    PassThrough{},
+                    c_element_op);
     }
 }