Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
730c5fff
Commit
730c5fff
authored
Dec 03, 2024
by
coderfeli
Browse files
fix linear
parent
4525c5d7
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
123 additions
and
72 deletions
+123
-72
include/ck_tile/core/algorithm/coordinate_transform.hpp
include/ck_tile/core/algorithm/coordinate_transform.hpp
+28
-28
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+3
-0
include/ck_tile/core/container/array.hpp
include/ck_tile/core/container/array.hpp
+7
-1
include/ck_tile/core/container/tuple.hpp
include/ck_tile/core/container/tuple.hpp
+5
-0
include/ck_tile/core/tensor/load_tile.hpp
include/ck_tile/core/tensor/load_tile.hpp
+0
-16
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
+39
-0
include/ck_tile/core/tensor/tensor_coordinate.hpp
include/ck_tile/core/tensor/tensor_coordinate.hpp
+10
-0
include/ck_tile/core/tensor/tile_window_linear.hpp
include/ck_tile/core/tensor/tile_window_linear.hpp
+12
-10
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+10
-10
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+7
-5
No files found.
include/ck_tile/core/algorithm/coordinate_transform.hpp
View file @
730c5fff
...
...
@@ -146,7 +146,7 @@ struct pass_through : public base_transform<1, 1>
//
printf
(
"up_lengths_:"
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
//
printf
(
"}"
);
...
...
@@ -236,17 +236,17 @@ struct pad : public base_transform<1, 1>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
//
printf
(
"left_pad_length_: "
);
print
(
left_pad_length_
);
print
x
(
left_pad_length_
);
printf
(
", "
);
//
printf
(
"right_pad_length_: "
);
print
(
right_pad_length_
);
print
x
(
right_pad_length_
);
printf
(
"}"
);
}
...
...
@@ -337,12 +337,12 @@ struct left_pad
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
//
printf
(
"left_pad_length_: "
);
print
(
left_pad_length_
);
print
x
(
left_pad_length_
);
printf
(
"}"
);
}
...
...
@@ -437,12 +437,12 @@ struct right_pad : public base_transform<1, 1>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
//
printf
(
"right_pad_length_: "
);
print
(
right_pad_length_
);
print
x
(
right_pad_length_
);
printf
(
"}"
);
}
...
...
@@ -539,12 +539,12 @@ struct embed : public base_transform<1, UpLengths::size()>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
//
printf
(
"coefficients_: "
);
print
(
coefficients_
);
print
x
(
coefficients_
);
printf
(
"}"
);
}
...
...
@@ -706,12 +706,12 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
//
printf
(
"low_lengths_ "
);
print
(
low_lengths_
);
print
x
(
low_lengths_
);
printf
(
", "
);
//
printf
(
"up_lengths_ "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
"}"
);
}
...
...
@@ -837,17 +837,17 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
//
printf
(
"low_lengths_ "
);
print
(
low_lengths_
);
print
x
(
low_lengths_
);
printf
(
", "
);
//
printf
(
"low_lengths_scan_ "
);
print
(
low_lengths_scan_
);
print
x
(
low_lengths_scan_
);
printf
(
", "
);
//
printf
(
"up_lengths_ "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
"}"
);
}
...
...
@@ -965,12 +965,12 @@ struct unmerge : public base_transform<1, UpLengths::size()>
//
printf
(
"up_lengths_"
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
//
printf
(
"up_lengths_scan_"
);
print
(
up_lengths_scan_
);
print
x
(
up_lengths_scan_
);
printf
(
"}"
);
}
...
...
@@ -1030,7 +1030,7 @@ struct freeze : public base_transform<1, 0>
//
printf
(
"low_idx_: "
);
print
(
low_idx_
);
print
x
(
low_idx_
);
printf
(
"}"
);
}
...
...
@@ -1098,7 +1098,7 @@ struct insert : public base_transform<0, 1>
printf
(
"insert{"
);
//
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
"}"
);
}
...
...
@@ -1158,7 +1158,7 @@ struct replicate : public base_transform<0, UpLengths::size()>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
"}"
);
}
...
...
@@ -1245,17 +1245,17 @@ struct slice : public base_transform<1, 1>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
//
printf
(
"slice_begin_: "
);
print
(
slice_begin_
);
print
x
(
slice_begin_
);
printf
(
", "
);
//
printf
(
"slice_end_: "
);
print
(
slice_end_
);
print
x
(
slice_end_
);
printf
(
"}"
);
}
// namespace ck
...
...
@@ -1335,7 +1335,7 @@ struct modulo : public base_transform<1, 1>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
"}"
);
}
...
...
@@ -1431,7 +1431,7 @@ struct xor_t : public base_transform<2, 2>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
printf
(
"}"
);
...
...
@@ -1516,12 +1516,12 @@ struct offset : public base_transform<1, 1>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
//
printf
(
"offset_length_: "
);
print
(
offset_length_
);
print
x
(
offset_length_
);
printf
(
"}"
);
}
...
...
@@ -1602,7 +1602,7 @@ struct indexing : public base_transform<1, 1>
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
print
x
(
up_lengths_
);
printf
(
", "
);
printf
(
"}"
);
...
...
include/ck_tile/core/config.hpp
View file @
730c5fff
...
...
@@ -230,3 +230,6 @@
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
/// Debug helper that forwards to the argument's own `print()` member.
///
/// @tparam T  Any type that provides a callable `print()` member function.
/// @param  a  Object to print. It is taken by value, so the function works on
///            its own copy. The `{}` default is only usable when `T` is given
///            explicitly (e.g. `printx<Foo>()`), because `T` cannot be deduced
///            from a default argument.
template <typename T>
CK_TILE_HOST_DEVICE void printx(T a = {})
{
    a.print();
}
\ No newline at end of file
include/ck_tile/core/container/array.hpp
View file @
730c5fff
...
...
@@ -52,7 +52,13 @@ struct array
data
[
i
]
=
vlast
;
}
}
/// Debug-print this array as `array{size: N, data: v0,v1,...,}`.
///
/// Every element is converted to `int` and written with "%d," (so the list
/// carries a trailing comma). The closing brace is emitted here so that the
/// text stays balanced when this output is nested inside another printer's
/// `{ ... }` block.
CK_TILE_HOST_DEVICE void print() const
{
    // Cast explicitly so the argument type always matches the "%d" conversion.
    printf("array{size: %d, data: ", static_cast<int>(size()));
    for(index_t i = 0; i < size(); ++i)
    {
        printf("%d,", static_cast<int>(get(i)));
    }
    // Close the brace opened by the "array{" prefix above.
    printf("}");
}
template
<
typename
Y
,
typename
=
std
::
enable_if_t
<
std
::
is_convertible_v
<
Y
,
value_type
>
||
std
::
is_constructible_v
<
Y
,
value_type
>>>
...
...
include/ck_tile/core/container/tuple.hpp
View file @
730c5fff
...
...
@@ -195,6 +195,11 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
using
base
=
impl
::
tuple_base
<
make_index_sequence
<
sizeof
...(
T
)
>
,
T
...
>
;
CK_TILE_HOST_DEVICE
constexpr
tuple
()
=
default
;
/// Debug-print hook for tuple; currently a no-op.
///
/// The body contains only disabled formatting code, so calling `print()` on a
/// tuple produces no output.
/// NOTE(review): the disabled code refers to `Is`, which is not declared in
/// the part of the struct visible here -- confirm before re-enabling it.
CK_TILE_HOST_DEVICE void print() const
{
    // printf("tuple{size: %d, data: [", size());
    // ((printf("%d ", Is)), ...);
    // printf("]}");
}
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
template
<
typename
U
>
CK_TILE_HOST_DEVICE
constexpr
tuple
(
std
::
initializer_list
<
U
>
us
)
:
base
(
us
)
...
...
include/ck_tile/core/tensor/load_tile.hpp
View file @
730c5fff
...
...
@@ -50,22 +50,6 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
return
tile_window
.
load
(
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
/// Load through a linear tile window into a caller-provided distributed tensor.
///
/// Forwards to `tile_window.load(dst_tile, number<-1>{}, bool_constant<...>{})`
/// and returns whatever that call returns.
///
/// @tparam oob_conditional_check  Forwarded to `load()` as a `bool_constant`;
///                                defaults to `true`.
/// @param  dst_tile     Destination tensor, passed to `load()` by reference.
/// @param  tile_window  Linear tile window that performs the load.
///
/// NOTE(review): `number<-1>` presumably means "issue every access" (it matches
/// the `i_access = -1` default of `tile_window_linear::load`) -- confirm.
template <typename DistributedTensor_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    // Tag objects handed to the window's load(); both are passed by value.
    const auto access_tag = number<-1>{};
    const auto oob_tag    = bool_constant<oob_conditional_check>{};

    return tile_window.load(dst_tile, access_tag, oob_tag);
}
template
<
typename
DistributedTensor_
,
typename
BottomTensorView_
,
typename
WindowLengths_
,
...
...
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
View file @
730c5fff
...
...
@@ -89,6 +89,45 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor&
remove_cvref_t
<
decltype
(
top_dim_ids
)
>>
{
idx_hidden
};
}
// template <typename Adaptor, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate_debug(const Adaptor& adaptor,
// const TopIndex& idx_top)
// {
// static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
// "wrong! # of dimension inconsistent");
// constexpr index_t ntransform = Adaptor::get_num_of_transform();
// constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
// constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
// constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
// multi_index<ndim_hidden> idx_hidden;
// // idx_hidden.print();
// // initialize visible index
// set_container_subset(idx_hidden, top_dim_ids, idx_top);
// // calculate hidden index
// static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
// auto itran = itran_p1 - number<1>{};
// const auto& tran = adaptor.get_transforms().at(itran);
// tran.print();
// constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
// constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
// const auto idx_up = get_container_subset(idx_hidden, dims_up);
// multi_index<dims_low.size()> idx_low;
// tran.calculate_lower_index(idx_low, idx_up);
// set_container_subset(idx_hidden, dims_low, idx_low);
// idx_hidden.print();
// });
// return tensor_adaptor_coordinate<ndim_hidden,
// remove_cvref_t<decltype(bottom_dim_ids)>,
// remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
// }
template
<
bool
JudgeDoTransforms
=
true
,
typename
Adaptor
,
typename
AdaptorCoord
,
...
...
include/ck_tile/core/tensor/tensor_coordinate.hpp
View file @
730c5fff
...
...
@@ -66,6 +66,16 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tens
remove_cvref_t
<
decltype
(
TensorDesc
::
get_top_dimension_hidden_ids
())
>>
{
adaptor_coord
};
}
// template <typename TensorDesc, typename TopIndex>
// CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate_debug(const TensorDesc& tensor_desc,
// const TopIndex& idx_top)
// {
// const auto adaptor_coord = make_tensor_adaptor_coordinate_debug(tensor_desc, idx_top);
// return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
// remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
// adaptor_coord};
// }
template
<
bool
JudgeDoTransforms
=
true
,
typename
TensorDesc
,
typename
TensorCoord
,
typename
Index
>
CK_TILE_HOST_DEVICE
constexpr
void
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp
View file @
730c5fff
...
...
@@ -440,6 +440,13 @@ struct tile_window_linear
// we directly use BottomTensorView transform to compute the offset, in case padding
auto
bottom_tensor_coord
=
make_tensor_coordinate
(
BottomTensorView
{}.
get_tensor_descriptor
(),
linear_coord
);
// if(threadIdx.x == 0) {
// bottom_tensor_coord =
// make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
// printf("off00 %d %d\n",i_access, bottom_tensor_coord.get_offset() );
// bottom_tensor_coord.get_hidden_index().print();
// bottom_tensor_coord.get_index().print();
// }
return
bottom_tensor_coord
.
get_offset
();
}
else
...
...
@@ -468,14 +475,16 @@ struct tile_window_linear
CK_TILE_DEVICE
constexpr
auto
get_num_of_access
()
const
{
return
traits
::
NumAccess
;
}
template
<
typename
DistributedTensor
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
DistributedTensor
&
dst_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
...
...
@@ -518,13 +527,7 @@ struct tile_window_linear
};
WINDOW_DISPATCH_ISSUE
();
}
/// Value-returning convenience overload of `load`.
///
/// Creates a fresh static distributed tensor of `DataType` for `TileDstr`,
/// fills it through the destination-tensor overload of `load`, and returns it.
///
/// @tparam i_access               Access selector forwarded as `number<i_access>`;
///                                defaults to -1.
/// @tparam oob_conditional_check  Forwarded as a `bool_constant`; defaults to `true`.
/// @return The newly created tensor after the destination-tensor `load` has run on it.
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
    constexpr auto distribution = TileDstr{};

    auto loaded = make_static_distributed_tensor<DataType>(distribution);

    // Delegate the actual work to the overload that writes into an existing tensor.
    this->load(loaded, number<i_access>{}, bool_constant<oob_conditional_check>{});

    return loaded;
}
...
...
@@ -547,8 +550,7 @@ struct tile_window_linear
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
// read from bottom tensor
const
vector_t
vec_value
=
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
730c5fff
...
...
@@ -232,8 +232,8 @@ struct BlockGemmARegBRegCRegV2
CK_TILE_DEVICE
static
void
PrefetchLds
(
const
BlockWindow
&
block_window
,
BlockTensor
&
block_tensor
)
{
auto
tileDist
=
BlockTensor
::
get_tile_distribution
();
load_tile
(
block_tensor
,
make_tile_window
(
block_window
,
tileDist
));
//
load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
//
load_tile(block_tensor, make_tile_window(block_window, tileDist));
load_tile
(
block_tensor
,
make_tile_window_linear
(
block_window
,
tileDist
));
}
// C = A * B
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
730c5fff
...
...
@@ -260,12 +260,13 @@ struct GemmPipelineAGmemBGmemCRegV1
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
));
ALdsTile
a_block_tile0
;
BLdsTile
b_block_tile0
;
auto
a_lds_ld_window0
=
make_tile_window_linear
(
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
ALdsTileDistr
);
auto
a_lds_ld_window1
=
make_tile_window_linear
(
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
ALdsTileDistr
);
auto
a_lds_ld_window0
=
make_tile_window_linear
(
a_lds_window0
,
ALdsTileDistr
);
auto
a_lds_ld_window1
=
make_tile_window_linear
(
a_lds_window1
,
ALdsTileDistr
);
auto
b_lds_ld_window0
=
make_tile_window_linear
(
b_lds_window0
,
BLdsTileDistr
);
auto
b_lds_ld_window1
=
make_tile_window_linear
(
b_lds_window1
,
BLdsTileDistr
);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b
_ld
s
_window0
,
b_block_tile0
);
load_tile
(
b_block_tile0
,
b_lds
_ld_window0
);
// LDS write 1
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
...
...
@@ -285,9 +286,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping
{
block_sync_lds
();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b
_ld
s
_window1
,
b_block_tile1
);
load_tile
(
b_block_tile1
,
b_lds
_ld_window1
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
...
...
@@ -300,7 +300,7 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds
();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b
_ld
s
_window0
,
b_block_tile0
);
load_tile
(
b_block_tile0
,
b_lds
_ld_window0
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
...
...
@@ -319,7 +319,7 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds
();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b
_ld
s
_window1
,
b_block_tile1
);
load_tile
(
b_block_tile1
,
b_lds
_ld_window1
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
...
...
@@ -329,7 +329,7 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds
();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile
(
a_block_tile0
,
a_lds_ld_window0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b
_ld
s
_window0
,
b_block_tile0
);
load_tile
(
b_block_tile0
,
b_lds
_ld_window0
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
}
//1
...
...
@@ -344,7 +344,7 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds
();
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile
(
a_block_tile1
,
a_lds_ld_window1
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b
_ld
s
_window1
,
b_block_tile1
);
load_tile
(
b_block_tile1
,
b_lds
_ld_window1
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
// 2
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
730c5fff
...
...
@@ -59,14 +59,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
// TODO: this 8 is AK1! should be a policy parameter!
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kMPerBlock
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
kMPerBlock
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
a_lds_block_desc
=
transform_tensor_descriptor
(
a_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kMPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
make_tuple
(
make_pass_through_transform
(
number
<
kMPerBlock
>
{}
),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
8
>
{}
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
...
@@ -88,8 +88,10 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kNPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
// make_tuple(make_pass_through_transform(kNPerBlock),
// make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
make_tuple
(
make_pass_through_transform
(
number
<
kNPerBlock
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
8
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment