Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5a2d93d4
Commit
5a2d93d4
authored
Nov 28, 2024
by
coderfeli
Browse files
revert code
parent
6a07464b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
113 additions
and
13 deletions
+113
-13
include/ck_tile/core/tensor/load_tile.hpp
include/ck_tile/core/tensor/load_tile.hpp
+16
-0
include/ck_tile/core/tensor/tile_window_linear.hpp
include/ck_tile/core/tensor/tile_window_linear.hpp
+52
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+44
-12
No files found.
include/ck_tile/core/tensor/load_tile.hpp
View file @
5a2d93d4
...
@@ -46,6 +46,22 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
...
@@ -46,6 +46,22 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
return
tile_window
.
load
(
number
<-
1
>
{},
bool_constant
<
oob_conditional_check
>
{});
return
tile_window
.
load
(
number
<-
1
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
}
/// Fill an existing distributed tensor from a linear tile window.
///
/// Thin forwarding overload: hands `dst_tile` to `tile_window.load(...)` together
/// with the access selector `number<-1>` and the out-of-bounds-check tag, and
/// returns whatever that member call returns.
///
/// @tparam oob_conditional_check  Compile-time flag forwarded unchanged to
///                                `tile_window.load` as a `bool_constant`.
/// @param dst_tile     Destination tensor, passed by reference to `load`.
/// @param tile_window  Source window; only read here (taken by const reference).
///
/// NOTE(review): `number<-1>` presumably selects "all accesses" in the window's
/// dispatch logic -- confirm against the window's `load` implementation.
template <typename DistributedTensor_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    // Spell the two tag arguments out as named types so the call reads clearly.
    using access_selector = number<-1>;
    using oob_check_tag   = bool_constant<oob_conditional_check>;

    return tile_window.load(dst_tile, access_selector{}, oob_check_tag{});
}
template
<
typename
DistributedTensor_
,
template
<
typename
DistributedTensor_
,
typename
BottomTensorView_
,
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp
View file @
5a2d93d4
...
@@ -453,6 +453,58 @@ struct tile_window_linear
...
@@ -453,6 +453,58 @@ struct tile_window_linear
CK_TILE_DEVICE
constexpr
auto
get_num_of_access
()
const
{
return
traits
::
NumAccess
;
}
CK_TILE_DEVICE
constexpr
auto
get_num_of_access
()
const
{
return
traits
::
NumAccess
;
}
/// Load elements from the bottom tensor view into a caller-supplied tensor.
///
/// For each dispatched access index this reads one `vector_t` from the bottom
/// tensor view (using the cached coordinate and flag for that access) and
/// scatters its `traits::ScalarPerVector` scalars into `dst_tensor`'s thread
/// buffer at offsets computed from the tile distribution's ys-to-d descriptor.
///
/// @tparam DistributedTensor      Type of the destination tensor; must expose
///                                `get_thread_buffer()`.
/// @tparam i_access               Access selector consumed by
///                                `WINDOW_DISPATCH_ISSUE()` (defaults to -1).
///                                NOTE(review): -1 presumably means "issue every
///                                access" -- confirm against the macro definition.
/// @tparam oob_conditional_check  Forwarded to `get_vectorized_elements` as a
///                                `bool_constant`.
/// @param dst_tensor  Destination tensor, written in place. It is taken by
///                    reference: this function returns nothing, so a by-value
///                    parameter would discard every element written here.
template <typename DistributedTensor, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                         number<i_access>                     = {},
                         bool_constant<oob_conditional_check> = {}) const
{
    using vector_t = typename traits::vector_t;
    using SFC_Ys   = typename traits::SFC_Ys;

    constexpr auto tile_dstr = TileDstr{};

    // Per-access worker; invoked by WINDOW_DISPATCH_ISSUE() below, which relies
    // on the names `issue` and `i_access` being in scope.
    auto issue = [&](auto i_access_) {
        constexpr auto IAccess = number<i_access_>{};

        // Map the access index to its cached coordinate slot; the flag is
        // cached per access index.
        constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
        auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
        auto bottom_tensor_flag         = cached_flags_[IAccess];

        constexpr auto linear_offset = get_bottom_linear_offset(IAccess);

        // read from bottom tensor
        const vector_t vec_value =
            get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                bool_constant<oob_conditional_check>{});
#if 1
        // data index [y0, y1, ...]
        constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);

        // write into distributed tensor, one scalar of the vector at a time
        static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
            // Advance only the vectorized Y dimension by the scalar index j.
            constexpr auto idx_ys = generate_tuple(
                [&](auto jj) {
                    return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                },
                number<NDimY>{});

            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

            dst_tensor.get_thread_buffer().template at<d>() =
                vec_value.template get_as<DataType>()[j];
        });
#else
        // NOTE(review): disabled whole-vector store path. It references
        // `idx_ys_start`, which is not declared in this function, so it would
        // not compile if enabled as written.
        constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
        static_assert(d % traits::ScalarPerVector == 0);

        dst_tensor.get_thread_buffer().template get_as<vector_t>()(
            number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
    };

    WINDOW_DISPATCH_ISSUE();
}
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
CK_TILE_DEVICE
auto
load
(
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
5a2d93d4
...
@@ -210,7 +210,7 @@ struct BlockGemmARegBRegCRegV2
...
@@ -210,7 +210,7 @@ struct BlockGemmARegBRegCRegV2
auto
tileDist
=
BlockTensor
::
get_tile_distribution
();
auto
tileDist
=
BlockTensor
::
get_tile_distribution
();
return
load_tile
(
block_tensor
,
make_tile_window
(
block_window
,
tileDist
));
return
load_tile
(
block_tensor
,
make_tile_window
(
block_window
,
tileDist
));
// load_tile
_raw
(block_tensor, make_tile_window_linear
_raw
(block_window, tileDist));
// load_tile(block_tensor, make_tile_window_linear(block_window, tileDist));
// return;
// return;
}
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
5a2d93d4
...
@@ -256,8 +256,40 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -256,8 +256,40 @@ struct GemmPipelineAGmemBGmemCRegV1
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
{}));
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
{}));
ALdsTile
a_block_tile0
;
ALdsTile
a_block_tile0
;
BLdsTile
b_block_tile0
;
BLdsTile
b_block_tile0
;
load_tile
(
a_block_tile0
,
make_tile_window
(
a_lds_window0
,
ALdsTileDistr
{}));
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
load_tile
(
b_block_tile0
,
make_tile_window
(
b_lds_window0
,
BLdsTileDistr
{}));
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
// if (threadIdx.x == 64) {
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f, %f; ", type_convert<float>(a_block_tile0(i_j_idx)), type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
...
@@ -274,8 +306,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -274,8 +306,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping
// ping
{
{
block_sync_lds
();
block_sync_lds
();
load_tile
(
a_block_tile1
,
make_tile_window
(
a_lds_window1
,
ALdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
load_tile
(
b_block_tile1
,
make_tile_window
(
b_lds_window1
,
BLdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
...
@@ -286,8 +318,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -286,8 +318,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// pong
// pong
{
{
block_sync_lds
();
block_sync_lds
();
load_tile
(
a_block_tile0
,
make_tile_window
(
a_lds_window0
,
ALdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
load_tile
(
b_block_tile0
,
make_tile_window
(
b_lds_window0
,
BLdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
a_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
LocalPrefill
(
b_lds_window1
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
...
@@ -303,8 +335,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -303,8 +335,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
// 3
{
{
block_sync_lds
();
block_sync_lds
();
load_tile
(
a_block_tile1
,
make_tile_window
(
a_lds_window1
,
ALdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
load_tile
(
b_block_tile1
,
make_tile_window
(
b_lds_window1
,
BLdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
a_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
LocalPrefill
(
b_lds_window0
,
b_global_load_tile
,
b_element_func
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
...
@@ -312,8 +344,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -312,8 +344,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2
// 2
{
{
block_sync_lds
();
block_sync_lds
();
load_tile
(
a_block_tile0
,
make_tile_window
(
a_lds_window0
,
ALdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window0
,
a_block_tile0
);
load_tile
(
b_block_tile0
,
make_tile_window
(
b_lds_window0
,
BLdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
block_gemm
(
c_block_tile
,
a_block_tile1
,
b_block_tile1
);
}
}
//1
//1
...
@@ -324,8 +356,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -324,8 +356,8 @@ struct GemmPipelineAGmemBGmemCRegV1
}
else
{
}
else
{
{
{
block_sync_lds
();
block_sync_lds
();
load_tile
(
a_block_tile1
,
make_tile_window
(
a_lds_window1
,
ALdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_lds_window1
,
a_block_tile1
);
load_tile
(
b_block_tile1
,
make_tile_window
(
b_lds_window1
,
BLdsTileDistr
{})
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_lds_window1
,
b_block_tile1
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
block_gemm
(
c_block_tile
,
a_block_tile0
,
b_block_tile0
);
}
}
// 2
// 2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment