Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5eec2aef
Commit
5eec2aef
authored
Dec 27, 2022
by
Anthony Chang
Browse files
clean up
parent
8079a1b0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
14 deletions
+9
-14
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+6
-6
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+0
-3
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+1
-3
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+1
-1
No files found.
example/01_gemm/gemm_xdl_fp16.cpp
View file @
5eec2aef
...
...
@@ -40,7 +40,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
using
DeviceGemmInstance
=
DeviceGemmInstance
1
;
using
DeviceGemmInstance
=
DeviceGemmInstance
0
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
5eec2aef
...
...
@@ -234,11 +234,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
// TODO
ANT
: implement bias combination
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
#if 0
// TODO
ANT
: use alias
// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
...
...
@@ -329,7 +329,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO
ANT
: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const
index_t
num_dims
=
NumDimG
+
NumDimN
+
NumDimO
;
...
...
@@ -410,7 +410,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO
ANT
: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const
index_t
num_dims
=
NumDimG
+
NumDimN
+
NumDimO
;
...
...
@@ -738,7 +738,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_grid_desc_g_m_n_
,
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())}
{
// TODO
ANT
: implement bias addition
// TODO: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc1_biases
;
ignore
=
acc0_biases_gs_ms_ns_lengths
;
...
...
@@ -950,7 +950,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return
false
;
}
// TODO
ANT
: Check if tensor specialization & strides mismatch
// TODO: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
5eec2aef
...
...
@@ -209,9 +209,6 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
const
auto
M0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n_
.
GetLength
(
I0
),
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n_
.
GetLength
(
I1
),
NPerBlock
);
// TODO ANT: is this necessary?
// block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups
const
index_t
idx_ksplit
=
block_1d_id
/
(
M0
*
N0
);
block_1d_id
=
block_1d_id
%
(
M0
*
N0
);
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
5eec2aef
...
...
@@ -54,8 +54,7 @@ template <typename SrcData,
typename
SrcDesc
,
typename
DstDesc
,
typename
ElementwiseOperation
,
typename
SliceLengths
,
// TODO ANT: can we generalize to allow sub-wg slice transfer? need
// to distinguish what dimensions are spread across waves
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
DstVectorDim
,
index_t
DstScalarPerVector
,
...
...
@@ -137,7 +136,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for
<
0
,
DstScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
// Sequence<num_access, idx_1d.value, i.value, src_offset>{}.foo();
SrcData
v
;
...
...
library/include/ck/library/utility/check_err.hpp
View file @
5eec2aef
...
...
@@ -148,7 +148,7 @@ check_err(const Range& out,
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
32
)
if
(
err_count
<
5
)
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment