Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
20ddaeba
Commit
20ddaeba
authored
Apr 22, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
c5f1cdf7
43879b89
Changes
236
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
2171 additions
and
0 deletions
+2171
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp
...emm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp
+36
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp
...mm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp
+46
-0
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp
...le/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp
+26
-0
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
...ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
+213
-0
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp
...mm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp
+38
-0
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
...m/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
+55
-0
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
+200
-0
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...lock_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+251
-0
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
+218
-0
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
...lock_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
+18
-0
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
+25
-0
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
+18
-0
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+112
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+471
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+379
-0
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+65
-0
No files found.
Too many changes to show.
To preserve performance only
236 of 236+
files are displayed.
Plain diff
Email patch
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
AType_
,
typename
BType_
,
typename
CType_
,
typename
BlockWarps_
,
typename
WarpGemm_
>
struct
BlockGemmARegBSmemCRegV2CustomPolicy
{
using
AType
=
remove_cvref_t
<
AType_
>
;
using
BType
=
remove_cvref_t
<
BType_
>
;
using
CType
=
remove_cvref_t
<
CType_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
static
constexpr
index_t
kMWarps
=
BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNWarps
=
BlockWarps
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kKWarps
=
BlockWarps
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
WarpGemm_
>
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
return
make_tuple
(
WarpGemm
{},
kMWarps
,
kNWarps
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmARegBSmemCRegV2
// Default policy class should not be templated, put template on member functions instead
struct
BlockGemmARegBSmemCRegV2DefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
// FIXME
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1);
}
#else
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
{},
4
,
1
);
#endif
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Problem Description for BlockGemmASmemBSmemCRegV1
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmASmemBSmemCRegProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
>
struct
BlockGemmASmemBSmemCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockWindowTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockWindowTmp
&
a_block_window_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindowTmp
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindowTmp
::
DataType
>
&&
std
::
is_same_v
<
CDataType
,
typename
CBlockTensor
::
DataType
>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
index_t
MPerBlockPerIter
=
MPerBlock
/
MIterPerWarp
;
constexpr
index_t
NPerBlockPerIter
=
NPerBlock
/
NIterPerWarp
;
constexpr
index_t
KPerBlockPerIter
=
KPerBlock
/
KIterPerWarp
;
const
index_t
iMWarp
=
get_warp_id
()
/
NWarp
;
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
// construct A-warp-window
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kM
>
{},
number
<
WG
::
kK
>
{}),
a_block_window_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WG
::
kM
,
0
},
make_static_tile_distribution
(
typename
WG
::
AWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
a_warp_window_tmp
),
KIterPerWarp
>
,
MIterPerWarp
>
a_warp_windows
;
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
a_warp_windows
(
mIter
)(
kIter
)
=
a_warp_window_tmp
;
move_tile_window
(
a_warp_windows
(
mIter
)(
kIter
),
{
mIter
*
MPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
b_warp_window_tmp
),
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
NPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block window
const
auto
a_warp_tensor
=
load_tile
(
a_warp_windows
(
mIter
)(
kIter
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
const
auto
b_warp_tensor
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
// C = A * B
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
{
auto
c_block_tensor
=
MakeCBlockTile
();
operator
()(
c_block_tensor
,
a_block_tensor_tmp
,
b_block_window_tmp
);
return
c_block_tensor
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmASmemBSmemCRegV1
// Default policy class should not be templated, put template on member functions instead
template
<
typename
AType_
,
typename
BType_
,
typename
CType_
,
typename
BlockWarps_
,
typename
WarpGemm_
>
struct
BlockGemmASmemBSmemCRegV1CustomPolicy
{
using
AType
=
remove_cvref_t
<
AType_
>
;
using
BType
=
remove_cvref_t
<
BType_
>
;
using
CType
=
remove_cvref_t
<
CType_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
static
constexpr
index_t
kMWarps
=
BlockWarps
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNWarps
=
BlockWarps
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kKWarps
=
BlockWarps
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
WarpGemm_
>
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
return
make_tuple
(
WarpGemm
{},
kMWarps
,
kNWarps
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmASmemBSmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
BlockGemmASmemBSmemCRegV1DefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
static_assert(kBlockSize % get_warp_size() == 0, "wrong!");
constexpr index_t NumWarp = kBlockSize / get_warp_size();
if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 &&
kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
else
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
}
#else
return
make_tuple
(
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
#endif
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
CDataType
,
float
>
)
{
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
BlockGemmPipelineAGmemBGmemCRegV1
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
return
ck_tile
::
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
16
+
sizeof
(
BDataType
)
*
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cvref_t
<
typename
BDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kMPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kNPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kKPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// A tile in LDS
ADataType
*
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
));
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
b_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
,
b_lds_block_desc
);
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
auto
a_copy_lds_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
// B DRAM tile window for load
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
auto
b_copy_lds_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// B LDS tile for block GEMM
auto
b_lds_gemm_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
constexpr
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
// Acc register tile
auto
c_block_tile
=
decltype
(
block_gemm
(
a_lds_gemm_window
,
b_lds_gemm_window
)){};
// prefetch
// global read 0
auto
a_block_tile
=
load_tile
(
a_copy_dram_window
);
auto
b_block_tile
=
load_tile
(
b_copy_dram_window
);
{
// move to 1
move_tile_window
(
a_copy_dram_window
,
{
0
,
kKPerBlock
});
move_tile_window
(
b_copy_dram_window
,
{
0
,
kKPerBlock
});
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
// LDS write 0
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
}
index_t
iCounter
=
num_loop
-
1
;
do
{
// global read i + 1
a_block_tile
=
load_tile
(
a_copy_dram_window
);
b_block_tile
=
load_tile
(
b_copy_dram_window
);
block_sync_lds
();
// GEMM i
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
// move to i + 2
move_tile_window
(
a_copy_dram_window
,
{
0
,
kKPerBlock
});
move_tile_window
(
b_copy_dram_window
,
{
0
,
kKPerBlock
});
// LDS write i + 1
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
// LDS write i + 1
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
iCounter
--
;
}
while
(
iCounter
>
0
);
// tail
{
block_sync_lds
();
// GEMM num_loop - 1
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
return
c_block_tile
;
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
index_t
num_loop
,
void
*
p_smem
)
const
{
return
operator
()(
a_dram_block_window_tmp
,
[](
const
ADataType
&
a
)
{
return
a
;
},
b_dram_block_window_tmp
,
[](
const
BDataType
&
b
)
{
return
b
;
},
num_loop
,
p_smem
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif
1
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
namespace
ck_tile
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
a_lds_block_desc
=
transform_tensor_descriptor
(
a_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kMPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
a_lds_block_desc
;
}
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
namespace
ck_tile
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kNPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kNPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kNPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc
;
}
#elif 1
// fake XOR
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
namespace
ck_tile
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
a_lds_block_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
2
>
{},
number
<
kKPerBlock
>
{}),
number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
ADataType
);
constexpr
auto
a_lds_block_desc_d4_d5_d6
=
transform_tensor_descriptor
(
a_lds_block_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
kKPerBlock
>
{}),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
2
>
{})),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
a_lds_block_desc_m_k
;
}
// fake XOR
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
namespace
ck_tile
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
2
>
{},
number
<
kKPerBlock
>
{}),
number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
BDataType
);
constexpr
auto
b_lds_block_desc_d4_d5_d6
=
transform_tensor_descriptor
(
b_lds_block_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
kKPerBlock
>
{}),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
2
>
{})),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc_n_k
;
}
#endif
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
kMPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
#endif
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
#endif
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
;
return
BlockGemmASmemBSmemCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
struct
BlockGemmPipelineAGmemBGmemCRegV2
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
return
ck_tile
::
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
16
+
sizeof
(
BDataType
)
*
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cvref_t
<
typename
BDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kMPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kNPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
kKPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
// A tile in LDS
ADataType
*
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
));
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
b_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
,
b_lds_block_desc
);
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
auto
a_copy_lds_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
// B DRAM tile window for load
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
auto
b_copy_lds_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// B LDS tile for block GEMM
auto
b_lds_gemm_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
constexpr
auto
block_gemm
=
Policy
::
template
GetBlockGemm
<
Problem
>();
// Acc register tile
auto
c_block_tile
=
decltype
(
block_gemm
(
a_lds_gemm_window
,
b_lds_gemm_window
)){};
// prefetch
// global read 0
auto
a_block_tile
=
load_tile
(
a_copy_dram_window
);
auto
b_block_tile
=
load_tile
(
b_copy_dram_window
);
{
// move to 1
move_tile_window
(
a_copy_dram_window
,
{
0
,
kKPerBlock
});
move_tile_window
(
b_copy_dram_window
,
{
0
,
kKPerBlock
});
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
// global read 1
a_block_tile
=
load_tile
(
a_copy_dram_window
);
// LDS write 0
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
// global read 1
b_block_tile
=
load_tile
(
b_copy_dram_window
);
}
index_t
iCounter
=
num_loop
-
2
;
do
{
block_sync_lds
();
// GEMM i
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
// move to i + 2
move_tile_window
(
a_copy_dram_window
,
{
0
,
kKPerBlock
});
move_tile_window
(
b_copy_dram_window
,
{
0
,
kKPerBlock
});
// LDS write i + 1
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
// global read i + 2
a_block_tile
=
load_tile
(
a_copy_dram_window
);
// LDS write i + 1
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
// global read i + 2
b_block_tile
=
load_tile
(
b_copy_dram_window
);
iCounter
--
;
}
while
(
iCounter
>
0
);
// tail
{
block_sync_lds
();
// GEMM num_loop - 2
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
// LDS write num_loop - 1
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
block_sync_lds
();
// GEMM num_loop - 1
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
return
c_block_tile
;
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
index_t
num_loop
,
void
*
p_smem
)
const
{
return
operator
()(
a_dram_block_window_tmp
,
[](
const
ADataType
&
a
)
{
return
a
;
},
b_dram_block_window_tmp
,
[](
const
BDataType
&
b
)
{
return
b
;
},
num_loop
,
p_smem
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Default policy for BlockGemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated, put template on member functions instead
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
index_t
kMPerTile
,
index_t
kNPerTile
,
index_t
kKPerTile
>
struct
TileGemmShape
{
static
constexpr
index_t
kM
=
kMPerTile
;
static
constexpr
index_t
kN
=
kNPerTile
;
static
constexpr
index_t
kK
=
kKPerTile
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
namespace
ck_tile
{
// fp16
using
WarpGemmMfmaF16F16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
using
WarpGemmMfmaF16F16F32M16N16K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
using
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
// bf16
using
WarpGemmMfmaBf16Bf16F32M32N32K8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
>>
;
using
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
,
2
>>
;
// fp8
using
WarpGemmMfma_f32_32x32x16_fp8_fp8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8
=
WarpGemmImpl
<
WarpGemmAtrributeMfma
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
>>
;
using
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
>>
;
template
<
index_t
swizzle_factor
=
2
>
using
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
>
,
2
,
swizzle_factor
>>
;
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
namespace
ck_tile
{
template
<
typename
WarpGemmAttributeMfmaImpl_
>
struct
WarpGemmAtrributeMfma
{
using
Impl
=
remove_cvref_t
<
WarpGemmAttributeMfmaImpl_
>
;
using
ADataType
=
typename
Impl
::
ADataType
;
using
BDataType
=
typename
Impl
::
BDataType
;
using
CDataType
=
typename
Impl
::
CDataType
;
using
AVecType
=
typename
Impl
::
AVecType
;
using
BVecType
=
typename
Impl
::
BVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
Impl
{}(
c_vec
,
a_vec
,
b_vec
);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
return
Impl
{}(
a_vec
,
b_vec
);
}
};
template
<
typename
WarpGemmAttributeMfmaImpl_
,
index_t
kKIter
>
struct
WarpGemmAtrributeMfmaIterateK
{
static_assert
(
kKIter
>
0
,
"wrong!"
);
using
Impl
=
remove_cvref_t
<
WarpGemmAttributeMfmaImpl_
>
;
using
ADataType
=
typename
Impl
::
ADataType
;
using
BDataType
=
typename
Impl
::
BDataType
;
using
CDataType
=
typename
Impl
::
CDataType
;
using
AVecType
=
ext_vector_t
<
ADataType
,
vector_traits
<
typename
Impl
::
AVecType
>::
vector_size
*
kKIter
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kM
;
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kCNLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
static_for
<
0
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
constexpr
auto
I0
=
number
<
0
>
{};
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
// c = a * b
auto
c_vec
=
Impl
{}(
reinterpret_cast
<
const
buf_a
>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
]);
// c += a * b
static_for
<
1
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_a
>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_b
>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
return
c_vec
;
}
};
template
<
typename
WarpGemmAttributeMfmaImpl_
>
struct
WarpGemmAtrributeMfmaTransposedCDistribution
{
using
Impl
=
remove_cvref_t
<
WarpGemmAttributeMfmaImpl_
>
;
using
ADataType
=
typename
Impl
::
BDataType
;
using
BDataType
=
typename
Impl
::
ADataType
;
using
CDataType
=
typename
Impl
::
CDataType
;
using
AVecType
=
typename
Impl
::
BVecType
;
using
BVecType
=
typename
Impl
::
AVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
// swap A and B
Impl
{}(
c_vec
,
b_vec
,
a_vec
);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
// swap A and B
return
Impl
{}(
b_vec
,
a_vec
);
}
};
template
<
typename
WarpGemmAttributeMfmaImpl_
>
struct
WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
{
using
Impl
=
remove_cvref_t
<
WarpGemmAttributeMfmaImpl_
>
;
using
ADataType
=
typename
Impl
::
BDataType
;
using
BDataType
=
typename
Impl
::
ADataType
;
using
CDataType
=
typename
Impl
::
CDataType
;
using
AVecType
=
typename
Impl
::
BVecType
;
using
BVecType
=
typename
Impl
::
AVecType
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
/
(
Impl
::
kABKPerLane
*
Impl
::
kABKLane
*
2
),
Impl
::
kABKLane
,
2
,
Impl
::
kABKPerLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
tuple
<
sequence
<
2
,
1
,
1
,
1
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
2
,
1
,
3
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
/
2
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
*
2
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
// swap A and B
Impl
{}(
c_vec
,
b_vec
,
a_vec
);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
// swap A and B
return
Impl
{}(
b_vec
,
a_vec
);
}
};
template
<
typename
WarpGemmAttributeMfmaImpl_
,
index_t
kKIter
>
struct
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
{
using
Impl
=
remove_cvref_t
<
WarpGemmAttributeMfmaImpl_
>
;
// swap A and B
using
ADataType
=
typename
Impl
::
BDataType
;
using
BDataType
=
typename
Impl
::
ADataType
;
using
CDataType
=
typename
Impl
::
CDataType
;
using
AVecType
=
ext_vector_t
<
ADataType
,
vector_traits
<
typename
Impl
::
AVecType
>::
vector_size
*
kKIter
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
// swap A and B, value and type
static_for
<
0
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
constexpr
auto
I0
=
number
<
0
>
{};
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
// swap A and B, value and type
auto
c_vec
=
Impl
{}(
reinterpret_cast
<
const
buf_b
&>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
]);
static_for
<
1
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
});
return
c_vec
;
}
};
template
<
typename
WarpGemmAttributeMfmaImpl_
,
index_t
kKIter
,
index_t
SFactor_
=
2
>
struct
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
{
using
Impl
=
remove_cvref_t
<
WarpGemmAttributeMfmaImpl_
>
;
// swap A and B
using
ADataType
=
typename
Impl
::
BDataType
;
using
BDataType
=
typename
Impl
::
ADataType
;
using
CDataType
=
typename
Impl
::
CDataType
;
using
AVecType
=
ext_vector_t
<
ADataType
,
vector_traits
<
typename
Impl
::
AVecType
>::
vector_size
*
kKIter
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
vector_traits
<
typename
Impl
::
BVecType
>::
vector_size
*
kKIter
>
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
SFactor
=
SFactor_
;
// group how many CM1 together
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
#if 0
using BWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
Impl::kABKLane,
2,
Impl::kABKPerLane>,
sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
tuple<sequence<2, 1, 1, 1, 1>>,
tuple<sequence<0, 0, 2, 1, 3>>,
sequence<2>,
sequence<1>>;
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kCNLane>,
sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>;
#else
// TODO: more test not only 32x32
using
BWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
/
(
Impl
::
kCMLane
*
SFactor
*
Impl
::
kCM1PerLane
),
Impl
::
kCMLane
,
SFactor
,
Impl
::
kCM1PerLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
tuple
<
sequence
<
2
,
1
,
1
,
1
,
1
>>
,
tuple
<
sequence
<
0
,
0
,
2
,
1
,
3
>>
,
sequence
<
2
>
,
sequence
<
1
>>
;
using
CWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kCNLane
>
,
sequence
<
Impl
::
kCM0PerLane
/
SFactor
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
*
SFactor
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
2
,
2
>
,
sequence
<
0
,
2
>>
;
#endif
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
// swap A and B, value and type
static_for
<
0
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
using
buf_a
=
thread_buffer
<
typename
Impl
::
AVecType
,
kKIter
>
;
using
buf_b
=
thread_buffer
<
typename
Impl
::
BVecType
,
kKIter
>
;
constexpr
auto
I0
=
number
<
0
>
{};
// swap A and B, value and type
auto
c_vec
=
Impl
{}(
reinterpret_cast
<
const
buf_b
&>
(
b_vec
).
template
get_as
<
typename
Impl
::
BVecType
>()[
I0
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
).
template
get_as
<
typename
Impl
::
AVecType
>()[
I0
]);
static_for
<
1
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
reinterpret_cast
<
const
buf_b
&>
(
b_vec
)
.
template
get_as
<
typename
Impl
::
BVecType
>()[
iKIter
],
reinterpret_cast
<
const
buf_a
&>
(
a_vec
)
.
template
get_as
<
typename
Impl
::
AVecType
>()[
iKIter
]);
});
return
c_vec
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// FP16
struct
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
16
>
;
static
constexpr
index_t
kM
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
2
;
static
constexpr
index_t
kCNLane
=
32
;
static
constexpr
index_t
kCM0PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
struct
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
using
ADataType
=
fp16_t
;
using
BDataType
=
fp16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
fp16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
16
;
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
4
;
static
constexpr
index_t
kCNLane
=
16
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
// Bf16
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
16
>
;
static
constexpr
index_t
kM
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
8
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
2
;
static
constexpr
index_t
kCNLane
=
32
;
static
constexpr
index_t
kCM0PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__)
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
c_vec
,
0
,
0
,
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
#elif defined(__gfx908__)
CVecType
c_vec
{
0.
f
};
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x4bf16
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
c_vec
,
0
,
0
,
0
);
});
return
c_vec
;
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
struct
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
using
ADataType
=
bf16_t
;
using
BDataType
=
bf16_t
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
BVecType
=
ext_vector_t
<
bf16_t
,
4
>
;
using
CVecType
=
ext_vector_t
<
float
,
4
>
;
static
constexpr
index_t
kM
=
16
;
static
constexpr
index_t
kN
=
16
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMLane
=
16
;
static
constexpr
index_t
kBNLane
=
16
;
static
constexpr
index_t
kABKLane
=
4
;
static
constexpr
index_t
kABKPerLane
=
4
;
static
constexpr
index_t
kCMLane
=
4
;
static
constexpr
index_t
kCNLane
=
16
;
static
constexpr
index_t
kCM0PerLane
=
1
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__)
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x8bf16
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
c_vec
,
0
,
0
,
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#elif defined(__gfx908__)
CVecType
c_vec
{
0.
f
};
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
k
)
{
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x8bf16
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
4
>&>
(
a_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
4
>&>
(
b_vec
)
.
template
get_as
<
ext_vector_t
<
bf16_t
,
2
>
>
()[
number
<
k
>
{}],
c_vec
,
0
,
0
,
0
);
});
return
c_vec
;
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
// FP8
template
<
typename
AType_
,
typename
BType_
>
struct
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
using
ADataType
=
AType_
;
using
BDataType
=
BType_
;
using
CDataType
=
float
;
using
AVecType
=
ext_vector_t
<
ADataType
,
8
>
;
using
BVecType
=
ext_vector_t
<
BDataType
,
8
>
;
using
CVecType
=
ext_vector_t
<
CDataType
,
16
>
;
static
constexpr
index_t
kM
=
32
;
static
constexpr
index_t
kN
=
32
;
static
constexpr
index_t
kK
=
16
;
static
constexpr
index_t
kAMLane
=
32
;
static
constexpr
index_t
kBNLane
=
32
;
static
constexpr
index_t
kABKLane
=
2
;
static
constexpr
index_t
kABKPerLane
=
8
;
static
constexpr
index_t
kCMLane
=
2
;
static
constexpr
index_t
kCNLane
=
32
;
static
constexpr
index_t
kCM0PerLane
=
4
;
static
constexpr
index_t
kCM1PerLane
=
4
;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
c_vec
,
0
,
0
,
0
);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
a_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
float
b_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
#endif
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
CVecType
{
0.
f
},
0
,
0
,
0
));
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
fp8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
CVecType
{
0.
f
},
0
,
0
,
0
));
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
fp8_t
>
)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
CVecType
{
0.
f
},
0
,
0
,
0
));
else
if
constexpr
(
std
::
is_same_v
<
ADataType
,
bf8_t
>
&&
std
::
is_same_v
<
BDataType
,
bf8_t
>
)
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8
(
bit_cast
<
long
>
(
a_vec
),
bit_cast
<
long
>
(
b_vec
),
CVecType
{
0.
f
},
0
,
0
,
0
));
#elif defined(__gfx908__) || defined(__gfx90a__)
CVecType
c_vec
{
0.
f
};
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
a_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
ADataType
,
8
>&>
(
a_vec
)
.
template
get_as
<
ADataType
>()[
number
<
k
>
{}]);
float
b_f32
=
type_convert
<
float
>
(
reinterpret_cast
<
const
thread_buffer
<
BDataType
,
8
>&>
(
b_vec
)
.
template
get_as
<
BDataType
>()[
number
<
k
>
{}]);
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
});
return
c_vec
;
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
};
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
fp8_t
>
;
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
fp8_t
,
bf8_t
>
;
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
fp8_t
>
;
using
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
=
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
<
bf8_t
,
bf8_t
>
;
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace
ck_tile
{
namespace
impl
{
template
<
typename
AType
,
typename
BType
,
typename
CType
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
struct
WarpGemmMfmaDispatcher
;
// clang-format off
// fp16
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
// bf16
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
// fp8
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
// clang-format on
}
// namespace impl
template
<
typename
AType
,
typename
BType
,
typename
CType
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
using
WarpGemmMfmaDispatcher
=
typename
impl
::
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWave
,
NPerWave
,
KPerWave
,
TransposeC
>::
Type
;
}
// namespace ck_tile
Prev
1
…
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment