Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e511bb78
Commit
e511bb78
authored
Nov 26, 2024
by
coderfeli
Browse files
lds a,b ok
parent
d51f4e52
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
352 additions
and
159 deletions
+352
-159
CMakeLists.txt
CMakeLists.txt
+0
-4
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-1
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+4
-0
include/ck_tile/host/fill.hpp
include/ck_tile/host/fill.hpp
+4
-3
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
...ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
+150
-78
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
...m/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
+91
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+32
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+70
-70
No files found.
CMakeLists.txt
View file @
e511bb78
...
...
@@ -503,10 +503,6 @@ include_directories(BEFORE
)
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
if
(
BUILD_DEV
)
add_compile_options
(
-Werror
)
add_compile_options
(
-Weverything
)
endif
()
message
(
"CMAKE_CXX_FLAGS:
${
CMAKE_CXX_FLAGS
}
"
)
if
(
"
${
CMAKE_CXX_COMPILER_ID
}
"
MATCHES
"Clang"
)
...
...
cmake/EnableCompilerWarnings.cmake
View file @
e511bb78
...
...
@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
#
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
e511bb78
...
...
@@ -117,6 +117,10 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
// ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
// ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
// ck_tile::FillConstant<ADataType>
{
1.f
}
(a_m_k);
// ck_tile::FillConstant<BDataType>
{
1.f
}
(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
...
...
include/ck_tile/host/fill.hpp
View file @
e511bb78
...
...
@@ -211,15 +211,16 @@ struct FillNormalDistributionIntegerValue
template
<
typename
T
>
struct
FillMonotonicSeq
{
T
init_value_
{
0
};
T
init_value_
{
-
1024
};
T
step_
{
1
};
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
std
::
generate
(
first
,
last
,
[
=
,
n
=
init_value_
]()
mutable
{
T
step_start
=
init_value_
;
std
::
generate
(
first
,
last
,
[
&
,
n
=
init_value_
]()
mutable
{
auto
tmp
=
n
;
n
+=
step_
;
if
(
n
>
step_start
+
2047
)
{
step_start
+=
step_
;
n
=
step_start
;}
return
tmp
;
});
}
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
e511bb78
...
...
@@ -42,9 +42,6 @@ struct BlockGemmASmemBSmemCRegV1
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
// if(threadIdx.x == 0 && blockIdx.x==0) {
// printf("MPerBlock %d NPerBlock %d KPerBlock %d \n", MPerBlock, NPerBlock, KPerBlock);
// }
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
...
...
@@ -56,12 +53,12 @@ struct BlockGemmASmemBSmemCRegV1
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
index_t
MPerBlockPerIter
=
MPerBlock
/
MIterPerWarp
;
constexpr
index_t
NPerBlockPerIter
=
NPerBlock
/
NIterPerWarp
;
constexpr
index_t
KPerBlockPerIter
=
KPerBlock
/
KIterPerWarp
;
//
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
//
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
//
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const
index_t
iMWarp
=
get_warp_id
()
/
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
//
const index_t iMWarp = get_warp_id() / NWarp;
//
const index_t iNWarp = get_warp_id() % NWarp;
// if(threadIdx.x == 0 && blockIdx.x==0) {
// printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
...
...
@@ -69,91 +66,69 @@ struct BlockGemmASmemBSmemCRegV1
// MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
// construct A-warp-window
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kM
>
{},
number
<
WG
::
kK
>
{}),
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WG
::
kM
,
0
},
make_static_tile_distribution
(
typename
WG
::
AWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
a_warp_window_tmp
),
KIterPerWarp
>
,
MIterPerWarp
>
a_warp_windows
;
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
MPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
// construct B-warp-window
make_tuple
(
MPerBlock
,
KPerBlock
),
{
0
,
0
},
Policy
::
template
MakeALDSTileDistribution
<
Problem
>());
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
b_warp_window_tmp
),
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
NPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
make_tuple
(
NPerBlock
,
KPerBlock
),
{
0
,
0
},
Policy
::
template
MakeBLDSTileDistribution
<
Problem
>());
auto
a_block_tensor
=
load_tile
(
a_warp_window_tmp
);
auto
b_block_tensor
=
load_tile
(
b_warp_window_tmp
);
// if (threadIdx.x == 0) {
// printf("0\n");
// constexpr auto span_2d = decltype(a_block_tensor)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f %f,", type_convert<float>(a_block_tensor(i_j_idx)), type_convert<float>(b_block_tensor(i_j_idx)));
// });
// printf("\n");
// });
// }
// __syncthreads();
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
BWarpDstr
=
typename
WG
::
BWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
BWarpTensor
=
typename
WG
::
BWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
b_warp_y_lengths
=
to_sequence
(
BWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
b_warp_y_index_zeros
=
uniform_sequence_gen_t
<
BWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block window
const
auto
a_warp_tensor
=
load_tile
(
a_warp_windows
(
mIter
)(
kIter
));
// read A warp tensor from A Block window
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
const
auto
b_warp_tensor
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
// read B warp tensor from B block tensor
BWarpTensor
b_warp_tensor
;
b_warp_tensor
.
get_thread_buffer
()
=
b_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
nIter
,
kIter
>
{},
b_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
b_warp_y_lengths
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
...
...
@@ -173,6 +148,36 @@ struct BlockGemmASmemBSmemCRegV1
});
});
});
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// // hot loop:
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// // read A warp tensor from A block window
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// // read B warp tensor from B Block window
// // const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// // read C warp tensor from C block tensor
// CWarpTensor c_warp_tensor;
// c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// // warp GEMM
// WG{}(c_warp_tensor, a_warp_tensor(mIter, kIter), b_warp_tensor(nIter, kIter));
// // write C warp tensor into C block tensor
// c_block_tensor.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
// c_warp_tensor.get_thread_buffer());
// });
// });
// });
}
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
...
...
@@ -217,5 +222,72 @@ struct BlockGemmASmemBSmemCRegV1
return
c_block_tensor
;
}
};
// construct A-warp-window
// auto a_warp_window_tmp = make_tile_window(
// a_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kM>{}, number<WG::kK>{}),
// a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
// make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
// {a_warp_window_tmp}};
// for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
// MIterPerWarp>
// a_warp_windows;
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// move_tile_window(a_warp_windows(mIter)(kIter),
// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
// construct B-warp-window
// auto b_warp_window_tmp = make_tile_window(
// b_block_window.get_bottom_tensor_view(),
// make_tuple(number<WG::kN>{}, number<WG::kK>{}),
// b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
// make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
// #if 0 // FIXME: using array will cause register spill
// array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
// {b_warp_window_tmp}};
// for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
// {
// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
// {
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// }
// }
// #else
// statically_indexed_array<
// statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
// NIterPerWarp>
// b_warp_windows;
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
// move_tile_window(b_warp_windows(nIter)(kIter),
// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
// });
// });
// #endif
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
View file @
e511bb78
...
...
@@ -40,7 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
#else
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K
8
{},
2
,
2
);
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K
16TransposedCDistribution
{},
2
,
2
);
// return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
}
...
...
@@ -55,6 +55,96 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
static_assert
(
false
,
"Unsupported data type configuration for GEMM warp execution."
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALDSTileDistribution
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
static_assert
(
false
,
"Unsupported tensor_layout right now."
);
}
else
{
//Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
constexpr
index_t
K2
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K1
=
2
;
constexpr
index_t
K0
=
KPerBlock
/
K1
/
K2
;
//Number<mrepeat>{}, Number<mwaves>{}, Number<MPerXdl>{}))),
constexpr
index_t
M2
=
32
;
// MPERXDL
constexpr
index_t
M1
=
2
;
//MWAVE
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
2
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
else
{
static_assert
(
false
,
"Unsupported shape right now."
);
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLDSTileDistribution
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
static_assert
(
false
,
"Unsupported tensor_layout right now."
);
}
else
{
//Number<krepeat>{}, Number<klane>{}, Number<Kpack>{}))),
constexpr
index_t
K2
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
K1
=
2
;
constexpr
index_t
K0
=
KPerBlock
/
K1
/
K2
;
//Number<mrepeat>{}, Number<mwaves>{}, Number<MPerXdl>{}))),
constexpr
index_t
N2
=
32
;
// MPERXDL
constexpr
index_t
N1
=
2
;
//MWAVE
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
2
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
else
{
static_assert
(
false
,
"Unsupported shape right now."
);
}
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
e511bb78
...
...
@@ -133,7 +133,16 @@ struct GemmPipelineAGmemBGmemCRegV1
// global read 0
auto
a_block_tile
=
load_tile
(
a_copy_dram_window
);
auto
b_block_tile
=
load_tile
(
b_copy_dram_window
);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
{
// move to 1
move_tile_window
(
a_copy_dram_window
,
{
0
,
kKPerBlock
});
...
...
@@ -170,7 +179,17 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_block_tile
));
}
}
// __syncthreads();
// if (threadIdx.x == 0) {
// for (int j = 0; j < 256; j++) {
// for(int i = 0; i < 32; i++) {
// int ik0 = i /8;
// int ik1 = i % 8;
// printf("%f,", type_convert<float>(p_b_lds[ik1 + j * 8 + ik0 * 8 * 256]));
// }
// printf("\n");
// }
// }
index_t
iCounter
=
num_loop
-
1
;
while
(
iCounter
>
0
)
{
...
...
@@ -219,6 +238,17 @@ struct GemmPipelineAGmemBGmemCRegV1
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// if(abs(type_convert<float>(c_block_tile(i_j_idx)) - 32) > 0.1)
// printf("%d %f,", threadIdx.x, type_convert<float>(c_block_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
return
c_block_tile
;
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
e511bb78
...
...
@@ -54,7 +54,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
(
kMPerBlock
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
...
...
@@ -77,7 +77,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kNPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
(
kNPerBlock
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
...
...
@@ -130,74 +130,74 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
#elif 1
// fake XOR
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
namespace
ck_tile
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
a_lds_block_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
2
>
{},
number
<
kKPerBlock
>
{}),
number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
ADataType
);
constexpr
auto
a_lds_block_desc_d4_d5_d6
=
transform_tensor_descriptor
(
a_lds_block_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
kKPerBlock
>
{}),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
2
>
{})),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
a_lds_block_desc_m_k
;
}
// fake XOR
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
namespace
ck_tile
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
2
>
{},
number
<
kKPerBlock
>
{}),
number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
BDataType
);
constexpr
auto
b_lds_block_desc_d4_d5_d6
=
transform_tensor_descriptor
(
b_lds_block_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
kKPerBlock
>
{}),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
2
>
{})),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc_n_k
;
}
//
template <typename Problem>
//
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
//
{
//
using namespace ck_tile;
//
using ADataType = remove_cvref_t<typename Problem::ADataType>;
//
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
//
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
//
constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
//
make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
//
number<kKPerBlock>{});
//
constexpr index_t kK1 = 16 / sizeof(ADataType);
//
constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
//
a_lds_block_desc_d1_d2_d3,
//
make_tuple(
//
make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
//
make_pass_through_transform(2)),
//
make_tuple(sequence<0, 2>{}, sequence<1>{}),
//
make_tuple(sequence<0, 2>{}, sequence<1>{}));
//
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
//
a_lds_block_desc_d4_d5_d6,
//
make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
//
make_pass_through_transform(kKPerBlock)),
//
make_tuple(sequence<0, 1>{}, sequence<2>{}),
//
make_tuple(sequence<0>{}, sequence<1>{}));
//
return a_lds_block_desc_m_k;
//
}
//
//
fake XOR
//
template <typename Problem>
//
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
//
{
//
using namespace ck_tile;
//
using BDataType = remove_cvref_t<typename Problem::BDataType>;
//
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
//
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
//
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
//
make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
//
number<kKPerBlock>{});
//
constexpr index_t kK1 = 16 / sizeof(BDataType);
//
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
//
b_lds_block_desc_d1_d2_d3,
//
make_tuple(
//
make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
//
make_pass_through_transform(2)),
//
make_tuple(sequence<0, 2>{}, sequence<1>{}),
//
make_tuple(sequence<0, 2>{}, sequence<1>{}));
//
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
//
b_lds_block_desc_d4_d5_d6,
//
make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
//
make_pass_through_transform(kKPerBlock)),
//
make_tuple(sequence<0, 1>{}, sequence<2>{}),
//
make_tuple(sequence<0>{}, sequence<1>{}));
//
return b_lds_block_desc_n_k;
//
}
#endif
template
<
typename
Problem
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment