Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
82072168
Commit
82072168
authored
Sep 17, 2023
by
Jing Zhang
Browse files
add is_detected
parent
3cf22191
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
48 additions
and
79 deletions
+48
-79
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
+7
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
+0
-65
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
+5
-13
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+34
-0
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+2
-0
No files found.
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
View file @
82072168
...
...
@@ -126,13 +126,19 @@ struct ThreadGroupTensorSliceTransfer_v7r2
}
}
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
template
<
typename
DstBuffers
>
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunWrite
(
dst_descs
,
dst_bufs
);
if
constexpr
(
is_detected
<
is_tuple
,
decltype
(
dst_bufs
)
>::
value
)
threadwise_transfer_
.
RunWrite
(
dst_descs
,
dst_bufs
);
else
threadwise_transfer_
.
RunWrite
(
dst_descs
,
tie
(
dst_bufs
));
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
82072168
...
...
@@ -687,70 +687,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
(
as_grid_desc_ak0_m_ak1
[
I0
].
GetLength
(
I0
)
*
as_grid_desc_ak0_m_ak1
[
I0
].
GetLength
(
I2
))
/
KPerBlock
);
#if 1
{
const
auto
a_grid_desc
=
as_grid_desc_ak0_m_ak1
;
const
auto
b_grid_desc
=
bs_grid_desc_bk0_n_bk1
;
const
auto
a_block_copy_step
=
a_block_slice_copy_step
;
const
auto
b_block_copy_step
=
b_block_slice_copy_step
;
const
auto
a_block_desc
=
a_block_desc_ak0_m_ak1
;
const
auto
b_block_desc
=
b_block_desc_bk0_n_bk1
;
const
auto
a_grid_bufs
=
as_grid_buf
;
const
auto
b_grid_bufs
=
bs_grid_buf
;
// preload data into LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_bufs
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_bufs
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
a_blockwise_copy
.
RunWrite
(
tie
(
a_block_desc
),
tie
(
a_block_buf
));
b_blockwise_copy
.
RunWrite
(
tie
(
b_block_desc
),
tie
(
b_block_buf
));
const
auto
num_loop
=
num_k_block_main_loop
;
// main body
if
constexpr
(
HasMainKBlockLoop
)
{
index_t
k
=
0
;
do
{
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_bufs
);
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_bufs
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_blockwise_copy
.
RunWrite
(
tie
(
a_block_desc
),
tie
(
a_block_buf
));
b_blockwise_copy
.
RunWrite
(
tie
(
b_block_desc
),
tie
(
b_block_buf
));
++
k
;
}
while
(
k
<
(
num_loop
-
1
));
}
// tail
{
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
#else
// gridwise GEMM pipeline
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
();
...
...
@@ -770,7 +706,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
blockwise_gemm
,
c_thread_buf
,
num_k_block_main_loop
);
#endif
// shuffle C and write out
{
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
View file @
82072168
...
...
@@ -7,18 +7,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include <type_traits>
template
<
typename
T
,
typename
=
void
>
struct
has_vec_len
:
std
::
false_type
{
};
template
<
typename
T
>
struct
has_vec_len
<
T
,
std
::
void_t
<
decltype
(
std
::
declval
<
T
>
().
vec_len
)
>>
:
std
::
true_type
{
};
#include "ck/utility/is_detected.hpp"
namespace
ck
{
...
...
@@ -143,6 +132,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2
Number
<
num
>
{});
}
template
<
typename
T
>
using
has_vec_len
=
decltype
(
std
::
declval
<
T
&>
().
vec_len
());
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template
<
typename
SrcBuffers
,
...
...
@@ -167,7 +159,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
is_src_valid
);
});
if
constexpr
(
has_vec_len
<
decltype
(
element_op_
)
>::
value
)
if
constexpr
(
is_detected
<
has_vec_len
,
decltype
(
element_op_
)
>::
value
)
{
constexpr
auto
elem_op_vec_len
=
decltype
(
element_op_
)
::
vec_len
;
...
...
include/ck/utility/is_detected.hpp
0 → 100644
View file @
82072168
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
namespace
detail
{
template
<
class
Default
,
class
AlwaysVoid
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
{
using
value_t
=
std
::
false_type
;
using
type
=
Default
;
};
template
<
class
Default
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
<
Default
,
std
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
{
using
value_t
=
std
::
true_type
;
using
type
=
Op
<
Args
...
>
;
};
}
// namespace detail
struct
nonesuch
{
~
nonesuch
()
=
delete
;
nonesuch
(
nonesuch
const
&
)
=
delete
;
void
operator
=
(
nonesuch
const
&
)
=
delete
;
};
template
<
template
<
class
...
>
class
Op
,
class
...
Args
>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
}
// namespace ck
include/ck/utility/tuple.hpp
View file @
82072168
...
...
@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsTuple
()
{
return
true
;
}
};
template
<
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment