Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e6bb1dd7
Unverified
Commit
e6bb1dd7
authored
Jul 19, 2024
by
Po Yen Chen
Committed by
GitHub
Jul 19, 2024
Browse files
Merge branch 'develop' into feature/check-window-lengths
parents
9d6a3704
ab250afd
Changes
332
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
148 additions
and
175 deletions
+148
-175
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+49
-0
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+62
-3
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+4
-3
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp
...ile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp
+0
-25
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
+2
-2
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
...mm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
+25
-136
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
...emm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp
...mm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp
...emm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp
...mm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp
+1
-1
No files found.
Too many changes to show.
To preserve performance only
332 of 332+
files are displayed.
Plain diff
Email patch
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
e6bb1dd7
...
...
@@ -43,4 +43,53 @@ struct TileFmhaShape
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
};
template
<
typename
BlockTile_
,
// sequence<...
typename
Gemm0BlockWarps_
,
typename
Gemm0WarpTile_
,
typename
Gemm1BlockWarps_
,
typename
Gemm1WarpTile_
,
typename
Gemm2BlockWarps_
,
typename
Gemm2WarpTile_
,
typename
Gemm3BlockWarps_
,
typename
Gemm3WarpTile_
,
typename
Gemm4BlockWarps_
,
typename
Gemm4WarpTile_
>
struct
TileFmhaBwdShape
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
Gemm0BlockWarps
=
remove_cvref_t
<
Gemm0BlockWarps_
>
;
using
Gemm0WarpTile
=
remove_cvref_t
<
Gemm0WarpTile_
>
;
using
Gemm1BlockWarps
=
remove_cvref_t
<
Gemm1BlockWarps_
>
;
using
Gemm1WarpTile
=
remove_cvref_t
<
Gemm1WarpTile_
>
;
using
Gemm2BlockWarps
=
remove_cvref_t
<
Gemm2BlockWarps_
>
;
using
Gemm2WarpTile
=
remove_cvref_t
<
Gemm2WarpTile_
>
;
using
Gemm3BlockWarps
=
remove_cvref_t
<
Gemm3BlockWarps_
>
;
using
Gemm3WarpTile
=
remove_cvref_t
<
Gemm3WarpTile_
>
;
using
Gemm4BlockWarps
=
remove_cvref_t
<
Gemm4BlockWarps_
>
;
using
Gemm4WarpTile
=
remove_cvref_t
<
Gemm4WarpTile_
>
;
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
Gemm0BlockWarps
{},
multiplies
{},
number
<
1
>
{});
static_assert
(
NumWarps
==
reduce_on_sequence
(
Gemm1BlockWarps
{},
multiplies
{},
number
<
1
>
{})
&&
NumWarps
==
reduce_on_sequence
(
Gemm4BlockWarps
{},
multiplies
{},
number
<
1
>
{}));
static
constexpr
index_t
kM0
=
BlockTile
::
at
(
number
<
0
>
{});
// tile size along q seqlen
static
constexpr
index_t
kN0
=
BlockTile
::
at
(
number
<
1
>
{});
// tile size along k seqlen
static
constexpr
index_t
kK0
=
BlockTile
::
at
(
number
<
2
>
{});
// tile size along gemm0(Q@K^T) unroll
static
constexpr
index_t
kK1
=
BlockTile
::
at
(
number
<
3
>
{});
// tile size along gemm1(P^T@dO) unroll
static
constexpr
index_t
kK2
=
BlockTile
::
at
(
number
<
4
>
{});
// tile size along gemm2(dO@V^T) unroll
static
constexpr
index_t
kK3
=
BlockTile
::
at
(
number
<
5
>
{});
// tile size along gemm3(dS^T@Q) unroll
static
constexpr
index_t
kK4
=
BlockTile
::
at
(
number
<
6
>
{});
// tile size along gemm4(dS@K) unroll
static
constexpr
index_t
kQKHeaddim
=
BlockTile
::
at
(
number
<
7
>
{});
// Q & K headdim, used for pipeline that need load Q/Q^T or
// K/K^T at once
static
constexpr
index_t
kVHeaddim
=
BlockTile
::
at
(
number
<
8
>
{});
// V headdim, used for pipeline
// that need load V at once
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
namespace
ck_tile
{
...
...
@@ -11,8 +12,10 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool
kPadSeqLenK_
/* padding for seqlen_k */
,
bool
kPadHeadDimQ_
/* paddding for hdim_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
bool
kHasBias_
,
BlockAttentionBiasEnum
BiasEnum_
,
bool
kHasBiasGrad_
,
bool
kStoreLSE_
,
bool
kHasDropout_
,
bool
kDoFp8StaticQuant_
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaTraits
...
...
@@ -21,10 +24,66 @@ struct TileFmhaTraits
static
constexpr
bool
kPadSeqLenK
=
kPadSeqLenK_
;
static
constexpr
bool
kPadHeadDimQ
=
kPadHeadDimQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
bool
kHasBias
=
kHasBias_
;
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
bool
kHasBiasGrad
=
kHasBiasGrad_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kHasDropout
=
kHasDropout_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ
/* padding for seqlen_q */
,
bool
kPadSeqLenK
/* padding for seqlen_k */
,
bool
kPadHeadDimQ
/* paddding for hdim_q */
,
bool
kPadHeadDimV
/* paddding for hdim_v */
,
BlockAttentionBiasEnum
BiasEnum
,
bool
kHasBiasGrad
,
bool
kStoreLSE
,
bool
kHasDropout
,
bool
kDoFp8StaticQuant
,
bool
kHasUnevenSplits_
=
true
,
index_t
kBlockPerCu
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaFwdSplitKVTraits
:
TileFmhaTraits
<
kPadSeqLenQ
,
kPadSeqLenK
,
kPadHeadDimQ
,
kPadHeadDimV
,
BiasEnum
,
kHasBiasGrad
,
kStoreLSE
,
kHasDropout
,
kDoFp8StaticQuant
,
kBlockPerCu
>
{
// determine if some split (length) is not divisible by tile size
static
constexpr
bool
kHasUnevenSplits
=
kHasUnevenSplits_
;
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
bool
kStoreLSE_
,
bool
kDoFp8StaticQuant_
,
index_t
kLogMaxSplits_
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
TileFmhaFwdSplitKVCombineTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
bool
kStoreLSE
=
kStoreLSE_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
index_t
kMaxSplits
=
(
1
<<
kLogMaxSplits_
);
static_assert
(
kMaxSplits
<=
get_warp_size
()
||
kMaxSplits
%
get_warp_size
()
==
0
);
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
template
<
bool
kPadSeqLenQ_
/* padding for seqlen_q */
,
bool
kPadHeadDimV_
/* paddding for hdim_v */
,
index_t
kBlockPerCu_
=
2
/* hint to occupancy */
>
struct
TileFmhaBwdOGradDotOTraits
{
static
constexpr
bool
kPadSeqLenQ
=
kPadSeqLenQ_
;
static
constexpr
bool
kPadHeadDimV
=
kPadHeadDimV_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
e6bb1dd7
...
...
@@ -3,20 +3,21 @@
#pragma once
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp
deleted
100644 → 0
View file @
9d6a3704
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Problem Description for BlockGemmARegBGmemCReg
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmARegBGmemCRegProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using
BlockGemmARegBSmemCRegImpl
=
BlockGemmARegBSmemCRegV1
<
BlockGemm
ARegBSmemCReg
Problem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmARegBSmemCRegV1DefaultPolicy
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_b
g
mem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_b
s
mem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
...
...
@@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
1
>
{}];
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
//
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
//
KPerBlock == BlockGemmShape::kK,
//
"wrong!");
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
@@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1
});
}
// C = A * B
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensorTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensorTmp
{}.
get_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
...
...
@@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
constexpr
index_t
NPerBlockPerIter
=
NPerBlock
/
NIterPerWarp
;
constexpr
index_t
KPerBlockPerIter
=
KPerBlock
/
KIterPerWarp
;
const
index_t
iNWarp
=
get_warp_id
()
%
NWarp
;
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
...
...
@@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto
a_block_tensor
=
make_static_distributed_tensor
<
typename
ABlockTensorTmp
::
DataType
>
(
a_block_dstr
);
a_block_tensor
.
get_thread_buffer
()
=
a_block_tensor_tmp
.
get_thread_buffer
();
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array
<
statically_indexed_array
<
decltype
(
b_warp_window_tmp
),
KIterPerWarp
>
,
NIterPerWarp
>
b_warp_windows
;
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
b_warp_windows
(
nIter
)(
kIter
)
=
b_warp_window_tmp
;
move_tile_window
(
b_warp_windows
(
nIter
)(
kIter
),
{
nIter
*
NPerBlockPerIter
,
kIter
*
KPerBlockPerIter
});
});
});
#endif
// Construct C-Block-HostTensor
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
}
using
AWarpDstr
=
typename
WG
::
AWarpDstr
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
AWarpTensor
=
typename
WG
::
AWarpTensor
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
a_warp_y_index_zeros
=
uniform_sequence_gen_t
<
AWarpDstr
::
NDimY
,
0
>
{};
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
// hot loop:
static_for
<
0
,
KIterPerWarp
,
1
>
{}([
&
](
auto
kIter
)
{
static_for
<
0
,
MIterPerWarp
,
1
>
{}([
&
](
auto
mIter
)
{
// read A warp tensor from A block tensor
AWarpTensor
a_warp_tensor
;
a_warp_tensor
.
get_thread_buffer
()
=
a_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
kIter
>
{},
a_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
a_warp_y_lengths
));
static_for
<
0
,
NIterPerWarp
,
1
>
{}([
&
](
auto
nIter
)
{
// read B warp tensor from B Block window
const
auto
b_warp_tensor
=
load_tile
(
b_warp_windows
(
nIter
)(
kIter
));
// read C warp tensor from C block tensor
CWarpTensor
c_warp_tensor
;
c_warp_tensor
.
get_thread_buffer
()
=
c_block_tensor
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
WG
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
c_warp_tensor
.
get_thread_buffer
());
});
});
});
// C = A * B
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindowTmp
&
b_block_window_tmp
)
const
{
auto
c_block_tensor
=
MakeCBlockTile
();
operator
()(
c_block_tensor
,
a_block_tensor_tmp
,
b_block_window_tmp
);
return
c_block_tensor
;
}
};
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
Prev
1
…
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment