gaoqiong / composable_kernel / Commits

Commit b37cb71f
authored Oct 16, 2019 by Wen-Heng (Jack) Chung

Enable bwd wrw

parent c5143bca
Showing 6 changed files with 706 additions and 1551 deletions
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp   +167 -847
composable_kernel/include/tensor_operation/threadwise_gemm.hpp                       +124 -62
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp  +395 -636
composable_kernel/include/utility/config_amd.hpp.in                                  +1 -4
driver/include/tensor.hpp                                                            +15 -0
driver/src/driver.cpp                                                                +4 -2
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
View file @ b37cb71f
This diff is collapsed.
composable_kernel/include/tensor_operation/threadwise_gemm.hpp
View file @ b37cb71f

@@ -7,98 +7,160 @@

Old (threadwise_matrix_set_zero and the threadwise_matrix_copy free function):

namespace ck {

template <class Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
    for(index_t i = 0; i < Matrix::NRow(); ++i)
    {
        for(index_t j = 0; j < Matrix::NCol(); ++j)
        {
            const index_t id = Matrix::GetOffsetFromMultiIndex(i, j);

            p_thread[id] = Float(0);
        }
    }
}

template <class Float,
          class SrcMatrix,
          class DstMatrix,
          index_t NRow,
          index_t NCol,
          index_t DataPerRead>
__device__ void threadwise_matrix_copy(SrcMatrix,
                                       const Float* __restrict__ p_src,
                                       DstMatrix,
                                       Float* __restrict__ p_dst,
                                       Sequence<NRow, NCol>,
                                       Number<DataPerRead>)
{
    static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");

    constexpr auto src_mtx = SrcMatrix{};
    constexpr auto dst_mtx = DstMatrix{};

    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    for(index_t i = 0; i < NRow; ++i)
    {
        for(index_t j = 0; j < NCol; j += DataPerRead)
        {
            const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
            const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);

            *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
                *reinterpret_cast<const vector_t*>(&p_src[src_index]);
        }
    }
}

New (threadwise_matrix_set_zero and the ThreadwiseMatrixSliceCopy struct):

namespace ck {

template <typename Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
    for(index_t i = 0; i < Matrix::NRow(); ++i)
    {
        for(index_t j = 0; j < Matrix::NCol(); ++j)
        {
            const index_t id = Matrix::CalculateOffset(i, j);

            p_thread[id] = Float(0);
        }
    }
}

template <typename SrcMatrix,
          typename DstMatrix,
          index_t NSliceRow,
          index_t NSliceCol,
          index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy
{
    __device__ constexpr ThreadwiseMatrixSliceCopy()
    {
        static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
                          DstMatrix::RowStride() % DataPerAccess == 0,
                      "wrong! wrong alignment");

        static_assert(NSliceCol % DataPerAccess == 0,
                      "wrong! should be NSliceCol % DataPerAccess == 0");
    }

    template <typename Data>
    __device__ static void Run(const Data* p_src, Data* p_dst)
    {
        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;

        for(index_t i = 0; i < NSliceRow; ++i)
        {
            for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
            {
                const index_t src_index = SrcMatrix::CalculateOffset(i, j);
                const index_t dst_index = DstMatrix::CalculateOffset(i, j);

                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
            }
        }
    }
};
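The new ThreadwiseMatrixSliceCopy no longer takes descriptor objects, a Sequence, and a Number at the call site: the slice shape and access width are template parameters, and the source/destination descriptors are consumed purely through static member functions (RowStride and CalculateOffset). The following host-side sketch is illustrative only and is not part of the commit; RowMajorDesc and copy_slice are hypothetical names that mimic the static interface the struct assumes, with the vectorized reinterpret_cast path reduced to element-wise copies (the DataPerAccess == 1 case).

#include <array>
#include <cstdio>

// Hypothetical row-major descriptor exposing the static interface that
// ThreadwiseMatrixSliceCopy expects: NRow(), NCol(), RowStride(), CalculateOffset().
template <int NRow_, int NCol_, int RowStride_>
struct RowMajorDesc
{
    static constexpr int NRow() { return NRow_; }
    static constexpr int NCol() { return NCol_; }
    static constexpr int RowStride() { return RowStride_; }
    static constexpr int CalculateOffset(int i, int j) { return i * RowStride_ + j; }
};

// Host-side analogue of ThreadwiseMatrixSliceCopy<Src, Dst, NSliceRow, NSliceCol, 1>::Run:
// copy an NSliceRow x NSliceCol window element by element using the static offsets.
template <class SrcDesc, class DstDesc, int NSliceRow, int NSliceCol, class Data>
void copy_slice(const Data* p_src, Data* p_dst)
{
    for(int i = 0; i < NSliceRow; ++i)
        for(int j = 0; j < NSliceCol; ++j)
            p_dst[DstDesc::CalculateOffset(i, j)] = p_src[SrcDesc::CalculateOffset(i, j)];
}

int main()
{
    using Src = RowMajorDesc<2, 4, 4>; // 2x4 source, packed
    using Dst = RowMajorDesc<2, 4, 8>; // 2x4 window inside a destination with row stride 8

    std::array<float, 8>  src{0, 1, 2, 3, 4, 5, 6, 7};
    std::array<float, 16> dst{};

    copy_slice<Src, Dst, 2, 4>(src.data(), dst.data());

    std::printf("dst[8..11] = %g %g %g %g\n", dst[8], dst[9], dst[10], dst[11]); // 4 5 6 7
}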
Old (the threadwise_gemm free function):

template <class MatrixA,
          class MatrixB,
          class MatrixC,
          bool TransA,
          bool TransB,
          bool TransC,
          class FloatA,
          class FloatB,
          class FloatC>
__device__ void threadwise_gemm(MatrixA,
                                integral_constant<bool, TransA>,
                                const FloatA* __restrict__ p_a_thread,
                                MatrixB,
                                integral_constant<bool, TransB>,
                                const FloatB* __restrict__ p_b_thread,
                                MatrixC,
                                integral_constant<bool, TransC>,
                                FloatC* __restrict__ p_c_thread)
{
    static_if<TransA && (!TransB) && (!TransC)>{}([&](auto) {
        constexpr auto a_mtx = MatrixA{};
        constexpr auto b_mtx = MatrixB{};
        constexpr auto c_mtx = MatrixC{};

        constexpr index_t M = c_mtx.NRow();
        constexpr index_t N = c_mtx.NCol();
        constexpr index_t K = a_mtx.NRow(); // A is transposed

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t i = 0; i < M; ++i)
            {
                for(index_t j = 0; j < N; ++j)
                {
                    const index_t aindex = a_mtx.GetOffsetFromMultiIndex(k, i); // A is transposed
                    const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
                    const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

                    p_c_thread[cindex] += math::inner_product_with_conversion<FloatC>{}(
                        p_a_thread[aindex], p_b_thread[bindex]);
                }
            }
        }
    }).Else([&](auto fwd) {
        // not implemented
        static_assert(fwd(false), "wrong! support for this config is not implemented");
    });
}

} // namespace ck
#endif

New (the ThreadwiseGemmTransANormalBNormalC struct):

// C += transpose(A) * B
// Element of matrix can be vectorized data
template <typename MatrixA, typename MatrixB, typename MatrixC>
struct ThreadwiseGemmTransANormalBNormalC
{
    __device__ constexpr ThreadwiseGemmTransANormalBNormalC()
    {
        static_assert(MatrixA::NRow() == MatrixB::NRow() &&
                          MatrixA::NCol() == MatrixC::NRow() &&
                          MatrixB::NCol() == MatrixC::NCol(),
                      "wrong!");
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        constexpr index_t M = MatrixC::NRow();
        constexpr index_t N = MatrixC::NCol();
        constexpr index_t K = MatrixA::NRow(); // A is transposed

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t m = 0; m < M; ++m)
            {
                for(index_t n = 0; n < N; ++n)
                {
                    const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
                    const index_t bindex = MatrixB::CalculateOffset(k, n);
                    const index_t cindex = MatrixC::CalculateOffset(m, n);

                    p_c[cindex] +=
                        inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
                }
            }
        }
    }

#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        constexpr index_t M = MatrixC::NRow();
        constexpr index_t N = MatrixC::NCol();
        constexpr index_t K = MatrixA::NRow(); // A is transposed

        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t m = 0; m < M; ++m)
            {
                const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed

                static_if<N == 2>{}([&](auto) {
                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);

                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);

                    __outer_product_1x2(p_a[aindex],
                                        p_b[bindex_0],
                                        p_b[bindex_1],
                                        p_c[cindex_0],
                                        p_c[cindex_1]);
                });

                static_if<N == 4>{}([&](auto) {
                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
                    const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
                    const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);

                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
                    const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
                    const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);

                    __outer_product_1x4(p_a[aindex],
                                        p_b[bindex_0],
                                        p_b[bindex_1],
                                        p_b[bindex_2],
                                        p_b[bindex_3],
                                        p_c[cindex_0],
                                        p_c[cindex_1],
                                        p_c[cindex_2],
                                        p_c[cindex_3]);
                });
            }
        }
    }
#endif

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
        constexpr bool has_amd_asm =
            is_same<FloatC, float>{} &&
            ((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
             (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
             (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));

        static_if<has_amd_asm>{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); })
            .Else([&](auto) { Run_source(p_a, p_b, p_c); });
#else
        Run_source(p_a, p_b, p_c);
#endif
    }
};

} // namespace ck
#endif
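ThreadwiseGemmTransANormalBNormalC::Run_source accumulates C += transpose(A) * B, with A stored K x M (hence the transposed indexing A(k, m)), B stored K x N, and C stored M x N; Run then picks the inline-assembly path (Run_amd_asm with __outer_product_1x2 / __outer_product_1x4) when CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM is set and the element types are float, half2_t, or half4_t, and otherwise falls back to Run_source. The sketch below is illustrative only and not part of the commit: gemm_At_B is a hypothetical host-side analogue of Run_source using plain packed row-major arrays instead of ck descriptors, showing the k-m-n loop order and the transposed-A indexing.

#include <array>
#include <cstdio>

// Host-side analogue of Run_source: C (M x N) += A^T * B, where A is stored K x M
// and B is stored K x N, all packed row-major.
template <int M, int N, int K>
void gemm_At_B(const float* a, const float* b, float* c)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n]; // a indexed (k, m): A is transposed
}

int main()
{
    constexpr int M = 2, N = 2, K = 3;

    // A stored K x M; its transpose A^T is the 2x3 matrix [[1,3,5],[2,4,6]].
    std::array<float, K * M> a{1, 2,
                               3, 4,
                               5, 6};
    std::array<float, K * N> b{1, 0,
                               0, 1,
                               1, 1};
    std::array<float, M * N> c{};

    gemm_At_B<M, N, K>(a.data(), b.data(), c.data());

    // Expected C = A^T * B:
    // C[0][0] = 1*1 + 3*0 + 5*1 = 6,  C[0][1] = 1*0 + 3*1 + 5*1 = 8
    // C[1][0] = 2*1 + 4*0 + 6*1 = 8,  C[1][1] = 2*0 + 4*1 + 6*1 = 10
    std::printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);
}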
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
View file @ b37cb71f
This diff is collapsed.
composable_kernel/include/utility/config_amd.hpp.in
View file @ b37cb71f

@@ -4,12 +4,9 @@
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"
 #include "bfloat16_dev.hpp"

 #define CK_DEVICE_BACKEND_AMD 1
 #define CK_USE_AMD_INLINE_ASM 1
-#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0

 namespace ck {
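The functional change in this header is flipping CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 from 1 to 0. Flags in this generated config header are plain 0/1 preprocessor switches, so consumers choose a code path at compile time. The snippet below is an illustrative sketch only, not the actual ck code that reads this flag.

// Illustrative only: how a 0/1 config macro of this kind gates a code path at compile time.
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#endif

#include <cstdio>

int main()
{
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
    std::puts("blockwise generic slice copy v1: more compile-time-static variant");
#else
    std::puts("blockwise generic slice copy v1: default variant");
#endif
}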
driver/include/tensor.hpp
View file @ b37cb71f
This diff is collapsed.
driver/src/driver.cpp
View file @ b37cb71f
This diff is collapsed.