Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6c270303
Commit
6c270303
authored
Nov 27, 2024
by
dummycoderfe
Browse files
change pipelines to v4. compile ok
parent
c808fa65
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
602 additions
and
164 deletions
+602
-164
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+2
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+226
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_default_policy.hpp
...emm/block/block_gemm_areg_breg_creg_v2_default_policy.hpp
+33
-0
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
...ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
+0
-30
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+337
-131
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+4
-3
No files found.
include/ck_tile/ops/gemm.hpp
View file @
6c270303
...
@@ -6,8 +6,10 @@
...
@@ -6,8 +6,10 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
0 → 100644
View file @
6c270303
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_default_policy.hpp"
namespace
ck_tile
{
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBRegCRegV1DefaultPolicy
>
struct
BlockGemmARegBRegCRegV2
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensor
&
a_block_tensor
,
const
BBlockTensor
&
b_block_tensor
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockTensor
::
DataType
>>
&&
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
// M->N Warp
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
b_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
b_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
// check ABC-block-distribution
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
a_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
ABlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"A distribution is wrong!"
);
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
b_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
BBlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"B distribution is wrong!"
);
static_assert
(
std
::
is_same_v
<
remove_cvref_t
<
decltype
(
c_block_dstr_encode
)
>
,
remove_cvref_t
<
decltype
(
CBlockTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
())
>>
,
"C distribution is wrong!"
);
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
BWarpDstr
=
typename
WG
::
BWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
BWarpTensor
=
typename
WG
::
BWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
b_warp_y_lengths
=
to_sequence
(
BWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
b_warp_y_index_zeros
=
uniform_sequence_gen_t
<
BWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A Block window
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B block tensor
BWarpTensor
b_warp_tensor
;
b_warp_tensor
.
get_thread_buffer
()
=
b_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
nIter
,
kIter
>
{},
b_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
b_warp_y_lengths
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistribution
()
{
// M->N Warp
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
return
a_block_dstr
;
// return make_static_distributed_tensor<ADataType>(a_block_dstr);
}
CK_TILE_DEVICE
static
constexpr
auto
MakeBBlockDistribution
()
{
constexpr
auto
b_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
b_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
constexpr
auto
b_block_dstr
=
make_static_tile_distribution
(
b_block_dstr_encode
);
return
b_block_dstr
;
// return make_static_distributed_tensor<BDataType>(b_block_dstr);
}
// Prefetch lds
template
<
typename
BlockWindowTmp
,
typename
BlockTensor
>
CK_TILE_DEVICE
static
auto
PrefetchLds
(
const
BlockWindowTmp
&
block_window
,
BlockTensor
&
block_tensor
)
{
auto
tileDist
=
BlockTensor
::
get_tile_distribution
();
//.get_static_tile_distribution_encoding()
return
load_tile
(
block_tensor
,
make_tile_window
(
block_window
,
tileDist
));
}
// C = A * B
template
<
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensor
&
a_block_tensor
,
const
BBlockTensor
&
b_block_tensor
)
const
{
auto
c_block_tensor
=
MakeCBlockTile
();
operator
()(
c_block_tensor
,
a_block_tensor
,
b_block_tensor
);
return
c_block_tensor
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_default_policy.hpp
0 → 100644
View file @
6c270303
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmARegBRegCRegV2
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmARegBRegCRegV2DefaultPolicy
{
    // Selects the warp GEMM and the warp layout for a Problem.
    //
    // Returns a tuple (warp GEMM object, warps along M, warps along N).
    // Supported (ADataType, BDataType, CDataType) combinations:
    //   (half_t, half_t, float) -> WarpGemmMfmaF16F16F32M32N32K16,   2 x 2 warps
    //   (bf16_t, bf16_t, float) -> WarpGemmMfmaBf16Bf16F32M32N32K16, 2 x 2 warps
    // Any other combination is rejected at compile time.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16{}, 2, 2);
        }
        else
        {
            // Without this branch the function would deduce `void` and the failure would
            // surface far away, at the first `config.template at<N>()` in the block GEMM.
            // The condition is written in terms of Problem so it is only evaluated when this
            // branch is actually instantiated.
            static_assert(!std::is_same_v<Problem, Problem>,
                          "BlockGemmARegBRegCRegV2DefaultPolicy: unsupported A/B/C data type "
                          "combination (supported: fp16/fp16/fp32 and bf16/bf16/fp32)");
        }
    }
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
6c270303
...
@@ -148,36 +148,6 @@ struct BlockGemmASmemBSmemCRegV1
...
@@ -148,36 +148,6 @@ struct BlockGemmASmemBSmemCRegV1
});
});
});
});
});
});
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// // hot loop:
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// // read A warp tensor from A block window
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// // read B warp tensor from B Block window
// // const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// // read C warp tensor from C block tensor
// CWarpTensor c_warp_tensor;
// c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// // warp GEMM
// WG{}(c_warp_tensor, a_warp_tensor(mIter, kIter), b_warp_tensor(nIter, kIter));
// // write C warp tensor into C block tensor
// c_block_tensor.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
// c_warp_tensor.get_thread_buffer());
// });
// });
// });
}
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
6c270303
...
@@ -39,13 +39,14 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -39,13 +39,14 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
{
return
integer_
divide_ceil
(
return
integer_
least_multiple
(
sizeof
(
ADataType
)
*
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
16
)
*
2
+
16
+
integer_least_multiple
(
sizeof
(
BDataType
)
*
sizeof
(
BDataType
)
*
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
2
;
}
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
@@ -53,6 +54,23 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -53,6 +54,23 @@ struct GemmPipelineAGmemBGmemCRegV1
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
// Loads the tile at dram_tile_window's current position into dst_block_tile, then advances
// the window by kKPerBlock along its second coordinate so the next call reads the next tile.
// NOTE(review): the window is presumably over global memory (per its name) - confirm.
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
                                   SrcTileWindow& dram_tile_window) const
{
    // fetch the current tile
    load_tile(dst_block_tile, dram_tile_window);

    // step the source window: no movement in the first coordinate, kKPerBlock in the second
    move_tile_window(dram_tile_window, {0, kKPerBlock});
}
// Applies element_func element-wise to src_block_tile and stores the transformed tile
// through lds_tile_window. src_block_tile itself is left unmodified.
// NOTE(review): the destination is presumably an LDS window (per its name) - confirm.
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
                                 const SrcBlockTile& src_block_tile,
                                 const ElementFunction& element_func) const
{
    // transform and store in one expression; the transformed tile is a temporary
    store_tile(lds_tile_window, tile_elementwise_in(element_func, src_block_tile));
}
template
<
typename
ADramBlockWindowTmp
,
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
AElementFunction
,
...
@@ -75,23 +93,23 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -75,23 +93,23 @@ struct GemmPipelineAGmemBGmemCRegV1
"wrong!"
);
"wrong!"
);
// A tile in LDS
// A tile in LDS
ADataType
*
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
constexpr
index_t
a_lds_block_space_size_aligned
=
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
integer_least_multiple
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
);
16
;
constexpr
index_t
b_lds_block_space_size_aligned
=
integer_least_multiple
(
sizeof
(
BDataType
)
*
b_lds_block_desc
.
get_element_space_size
(),
16
);
ADataType
*
p_a_lds0
=
reinterpret_cast
<
ADataType
*>
(
p_smem
);
ADataType
*
p_a_lds1
=
reinterpret_cast
<
ADataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
);
// B tile in LDS
// B tile in LDS
BDataType
*
p_b_lds
=
static
_cast
<
BDataType
*>
(
BDataType
*
p_b_lds
0
=
reinterpret
_cast
<
BDataType
*>
(
reinterpret_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
*
2
);
static_cast
<
void
*>
(
static
_cast
<
char
*>
(
p_
smem
)
+
a
_lds_block_space_size_aligned
)
)
;
BDataType
*
p_b_lds1
=
reinterpret_cast
<
BDataType
*>
(
reinterpret
_cast
<
char
*>
(
p_
b_lds0
)
+
b
_lds_block_space_size_aligned
);
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
b_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
,
b_lds_block_desc
);
auto
a_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds0
,
a_lds_block_desc
);
auto
b_lds_block0
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds0
,
b_lds_block_desc
);
auto
a_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds1
,
a_lds_block_desc
);
auto
b_lds_block1
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds1
,
b_lds_block_desc
);
// A DRAM tile window for load
// A DRAM tile window for load
auto
a_copy_dram_window
=
auto
a_copy_dram_window
=
...
@@ -101,8 +119,10 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -101,8 +119,10 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
// A LDS tile window for store
auto
a_copy_lds_window
=
make_tile_window
(
auto
a_store_lds_window0
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
a_store_lds_window1
=
make_tile_window
(
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// B DRAM tile window for load
// B DRAM tile window for load
auto
b_copy_dram_window
=
auto
b_copy_dram_window
=
...
@@ -112,143 +132,144 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -112,143 +132,144 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
// B LDS tile window for store
auto
b_copy_lds_window
=
make_tile_window
(
auto
b_store_lds_window0
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
b_lds_block0
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
b_store_lds_window1
=
make_tile_window
(
b_lds_block1
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// A LDS tile for block GEMM
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
auto
a_load_lds_window0
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
a_lds_block0
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
a_load_lds_window1
=
make_tile_window
(
a_lds_block1
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// B LDS tile for block GEMM
// B LDS tile for block GEMM
auto
b_lds_gemm_window
=
make_tile_window
(
auto
b_load_lds_window0
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
b_lds_block0
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
auto
b_load_lds_window1
=
make_tile_window
(
b_lds_block1
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
// Block GEMM
constexpr
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
constexpr
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
// Acc register tile
// Acc register tile
auto
c_block_tile
=
decltype
(
block_gemm
(
a_lds_gemm_window
,
b_lds_gemm_window
)){};
auto
c_block_tile
=
Policy
::
template
BlockGemm
<
Problem
>
::
MakeCBlockTile
();
// a b register tile
auto
a_prefetch_tile0
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
auto
a_prefetch_tile1
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
());
auto
b_prefetch_tile0
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
auto
b_prefetch_tile1
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
());
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
{}));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
{}));
ABlockTile
a_global_load_tile
;
BBlockTile
b_global_load_tile
;
// prefetch
// prefetch
// global read 0
// global read 0
auto
a_block_tile
=
load_tile
(
a_copy_dram_window
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
auto
b_block_tile
=
load_tile
(
b_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
{
// move to 1
move_tile_window
(
a_copy_dram_window
,
{
0
,
kKPerBlock
});
move_tile_window
(
b_copy_dram_window
,
{
0
,
kKPerBlock
});
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
// LDS write 0
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
LocalPrefill
(
a_store_lds_window0
,
a_global_load_tile
,
a_element_func
);
{
LocalPrefill
(
b_store_lds_window0
,
b_global_load_tile
,
b_element_func
);
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegBlockDescriptor
<
Problem
>());
block_sync_lds
();
shuffle_tile
(
a_shuffle_tmp
,
a_block_tile
);
// global read 1
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_shuffle_tmp
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
}
// local prefetch 0
else
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_prefetch_tile0
);
{
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_prefetch_tile0
);
store_tile
(
a_copy_lds_window
,
tile_elementwise_in
(
a_element_func
,
a_block_tile
));
}
// LDS write 1
LocalPrefill
(
a_store_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window1
,
b_global_load_tile
,
b_element_func
);
// global read 2
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
// LDS write 0
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp
,
b_block_tile
);
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
}
else
{
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_block_tile
));
}
}
// __syncthreads();
// if (threadIdx.x == 0) {
// for (int j = 0; j < 256; j++) {
// for(int i = 0; i < 32; i++) {
// int ik0 = i /8;
// int ik1 = i % 8;
// printf("%f,", type_convert<float>(p_b_lds[ik1 + j * 8 + ik0 * 8 * 256]));
// }
// printf("\n");
// }
// }
index_t
iCounter
=
num_loop
-
1
;
index_t
iCounter
=
num_loop
-
1
;
while
(
iCounter
>
0
)
while
(
iCounter
>
2
)
{
// ping
{
{
// global read i + 1
a_block_tile
=
load_tile
(
a_copy_dram_window
);
b_block_tile
=
load_tile
(
b_copy_dram_window
);
block_sync_lds
();
block_sync_lds
();
// GEMM i
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_prefetch_tile1
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_prefetch_tile1
);
LocalPrefill
(
a_store_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window0
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_prefetch_tile0
,
b_prefetch_tile0
);
}
__builtin_amdgcn_sched_barrier
(
0
);
// pong
{
block_sync_lds
();
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_prefetch_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_prefetch_tile0
);
LocalPrefill
(
a_store_lds_window1
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window1
,
b_global_load_tile
,
b_element_func
);
GlobalPrefetch
(
a_global_load_tile
,
a_copy_dram_window
);
GlobalPrefetch
(
b_global_load_tile
,
b_copy_dram_window
);
block_gemm
(
c_block_tile
,
a_prefetch_tile1
,
b_prefetch_tile1
);
// move to i + 2
}
move_tile_window
(
a_copy_dram_window
,
{
0
,
kKPerBlock
});
move_tile_window
(
b_copy_dram_window
,
{
0
,
kKPerBlock
});
// LDS write i + 1
iCounter
-=
2
;
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
}
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
// LDS write i + 1
//tail 3
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
(
iCounter
==
1
)
{
// 3
{
{
auto
b_shuffle_tmp_loop
=
make_static_distributed_tensor
<
BDataType
>
(
block_sync_lds
();
Policy
::
template
MakeShuffledBRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp_loop
,
b_block_tile
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_prefetch_tile1
);
store_tile
(
b_copy_lds_window
,
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_prefetch_tile1
);
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp_loop
));
LocalPrefill
(
a_store_lds_window0
,
a_global_load_tile
,
a_element_func
);
LocalPrefill
(
b_store_lds_window0
,
b_global_load_tile
,
b_element_func
);
block_gemm
(
c_block_tile
,
a_prefetch_tile0
,
b_prefetch_tile0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
// 2
{
{
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
block_sync_lds
();
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window0
,
a_prefetch_tile0
);
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window0
,
b_prefetch_tile0
);
block_gemm
(
c_block_tile
,
a_prefetch_tile1
,
b_prefetch_tile1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
//1
iCounter
--
;
{
block_gemm
(
c_block_tile
,
a_prefetch_tile0
,
b_prefetch_tile0
);
}
}
//tail 2
// tail
}
else
{
{
{
block_sync_lds
();
block_sync_lds
();
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
a_load_lds_window1
,
a_prefetch_tile1
);
// GEMM num_loop - 1
Policy
::
template
BlockGemm
<
Problem
>
::
PrefetchLds
(
b_load_lds_window1
,
b_prefetch_tile1
);
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_gemm
(
c_block_tile
,
a_prefetch_tile0
,
b_prefetch_tile0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
// 2
{
block_gemm
(
c_block_tile
,
a_prefetch_tile1
,
b_prefetch_tile1
);
}
}
}
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
return
c_block_tile
;
return
c_block_tile
;
}
}
...
@@ -268,4 +289,189 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -268,4 +289,189 @@ struct GemmPipelineAGmemBGmemCRegV1
}
}
};
};
// __device__ static constexpr auto HotLoopScheduler()
// {
// // schedule
// constexpr auto num_ds_read_inst =
// HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
// constexpr auto num_ds_write_inst =
// HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
// ;
// constexpr auto num_buffer_load_inst =
// HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
// ;
// constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
// constexpr auto num_issue = num_buffer_load_inst;
// static_for<0, num_issue, 1>{}([&](auto i) {
// ignore = i;
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(
// 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
// });
// }
// CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
// {
// constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(number<0>{});
// constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(number<1>{});
// constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(number<2>{});
// constexpr index_t WaveSize = 64;
// constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(number<0>{});
// constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(number<1>{});
// constexpr index_t A_LDS_Read_Width = KPerXDL;
// constexpr index_t B_LDS_Read_Width = KPerXDL;
// constexpr index_t A_Buffer_Load_Inst_Num =
// MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
// constexpr index_t B_Buffer_Load_Inst_Num =
// NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
// constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t A_LDS_Read_Inst_Num =
// WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t B_LDS_Read_Inst_Num =
// WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
// constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
// (BlockSize / WaveSize) /
// (MPerXDL * NPerXDL * KPerXDL);
// // A/B split schedule
// // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
// ? A_LDS_Read_Inst_Num
// : A_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
// ? B_LDS_Read_Inst_Num
// : B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
// constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
// constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
// constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
// constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
// constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
// constexpr auto ds_read_a_issue_cycle =
// A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
// constexpr auto ds_read_b_issue_cycle =
// B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// constexpr auto ds_read_a_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
// constexpr auto ds_read_b_mfma_rate =
// (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
// constexpr auto num_dsread_b_mfma =
// (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// // stage 1
// // Separate this part?
// // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// // sizeof(ComputeDataType) /
// // sizeof(BDataType)
// // ? sizeof(ComputeDataType) /
// // sizeof(ADataType) : sizeof(ComputeDataType)
// // / sizeof(BDataType);
// constexpr auto num_mfma_stage1 =
// num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
// constexpr auto num_mfma_per_issue =
// num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
// constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
// constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
// static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
// });
// static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
// ignore = i;
// static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
// ignore = idswrite;
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(
// 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
// });
// // stage 2
// static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
// ds_read_a_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
// if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
// ds_read_b_mfma_rate)
// {
// __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
// }
// else
// {
// __builtin_amdgcn_sched_group_barrier(
// 0x100,
// num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
// 0); // DS read
// }
// __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
6c270303
...
@@ -11,6 +11,9 @@ namespace ck_tile {
...
@@ -11,6 +11,9 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
// Default policy class should not be templated, put template on member functions instead
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
{
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV2DefaultPolicy
;
template
<
typename
Problem
>
using
BlockGemm
=
BlockGemmARegBRegCRegV2
<
Problem
,
BlockGemmPolicy
>
;
#if 0
#if 0
// 2d
// 2d
...
@@ -472,9 +475,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -472,9 +475,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
{
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
;
return
BlockGemm
<
Problem
>
{};
return
BlockGemmASmemBSmemCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment