Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8c967d76
Commit
8c967d76
authored
Jul 11, 2024
by
danyao12
Browse files
fix batch deterministic bugs
parent
74f1516c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
188 deletions
+41
-188
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+10
-16
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
...e/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
+8
-9
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+3
-36
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+20
-127
No files found.
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
8c967d76
...
@@ -766,7 +766,7 @@ struct FmhaBwdDQDKDVKernel
...
@@ -766,7 +766,7 @@ struct FmhaBwdDQDKDVKernel
make_naive_tensor_view
<
address_space_enum
::
global
>
(
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
hdim
_q
,
1
),
make_tuple
(
kargs
.
stride
_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -1487,22 +1487,18 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1487,22 +1487,18 @@ struct FmhaBwdConvertQGradKernel
{
{
const
AccDataType
*
dq_acc_ptr
=
const
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
(
kargs
.
seqlen_q
*
kargs
.
hdim_q
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
(
kargs
.
nhead_stride_dq
)
+
batch_offset_dq
;
batch_offset_dq
;
const
index_t
nsplits
=
ck_tile
::
integer_divide_ceil
(
kargs
.
seqlen_k
,
kN0
);
const
index_t
nsplits
=
ck_tile
::
integer_divide_ceil
(
kargs
.
seqlen_k
,
kN0
);
constexpr
auto
dq_fold
=
4
;
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
dq_acc_ptr
,
make_tuple
(
nsplits
,
kargs
.
seqlen_q
/
dq_fold
,
kargs
.
hdim_q
*
dq_fold
),
make_tuple
(
nsplits
,
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
split_stride_dq_acc
,
kargs
.
hdim_q
*
dq_fold
,
1
),
make_tuple
(
kargs
.
split_stride_dq_acc
,
kargs
.
stride_dq
,
1
),
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGradAcc
>
{},
number
<
FmhaBwdConvertQGrad
::
kAlignmentQGradAcc
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
1
>
{},
make_tuple
(
number
<
1
>
{},
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
number
<
kM0
/
dq_fold
>
{},
number
<
kQKHeaddim
*
dq_fold
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
}
else
else
...
@@ -1538,12 +1534,10 @@ struct FmhaBwdConvertQGradKernel
...
@@ -1538,12 +1534,10 @@ struct FmhaBwdConvertQGradKernel
auto
dq_acc_dram_window
=
[
&
]()
{
auto
dq_acc_dram_window
=
[
&
]()
{
if
constexpr
(
kIsDeterministic
)
if
constexpr
(
kIsDeterministic
)
{
{
constexpr
auto
dq_fold
=
4
;
return
make_tile_window
(
return
make_tile_window
(
dq_acc_dram
,
dq_acc_dram
,
make_tuple
(
number
<
1
>
{},
make_tuple
(
number
<
1
>
{},
number
<
kM0
>
{},
number
<
kQKHeaddim
>
{}),
number
<
kM0
/
dq_fold
>
{},
{
0
,
i_m0
,
0
});
number
<
kQKHeaddim
*
dq_fold
>
{}),
{
0
,
i_m0
/
dq_fold
,
0
});
}
}
else
else
{
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
View file @
8c967d76
...
@@ -52,7 +52,7 @@ struct BlockFmhaBwdConvertQGrad
...
@@ -52,7 +52,7 @@ struct BlockFmhaBwdConvertQGrad
make_tile_window
(
dq_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
dq_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_acc_dram_block_window_tmp
.
get_window_lengths
(),
dq_acc_dram_block_window_tmp
.
get_window_lengths
(),
dq_acc_dram_block_window_tmp
.
get_window_origin
(),
dq_acc_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakePostQGrad
Acc
DramTileDistribution
<
Problem
>());
Policy
::
template
MakePostQGradDramTileDistribution
<
Problem
>());
auto
dq_acc
=
load_tile
(
dq_acc_dram_window
);
auto
dq_acc
=
load_tile
(
dq_acc_dram_window
);
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
const
auto
dq
=
cast_tile
<
QGradDataType
>
(
dq_acc
);
...
@@ -76,11 +76,11 @@ struct BlockFmhaBwdConvertQGrad
...
@@ -76,11 +76,11 @@ struct BlockFmhaBwdConvertQGrad
static_assert
(
kM0
==
QGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}],
"wrong!"
);
static_assert
(
kM0
==
QGradDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}],
"wrong!"
);
auto
dq_acc_dram_window
=
make_tile_window
(
auto
dq_acc_dram_window
=
dq_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
dq_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
dq_acc_dram_block_window_tmp
.
get_window_lengths
(),
dq_acc_dram_block_window_tmp
.
get_window_lengths
(),
dq_acc_dram_block_window_tmp
.
get_window_origin
(),
dq_acc_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakePostQGradAccD
eterministicD
ramTileDistribution
<
Problem
>());
Policy
::
template
MakePostQGradAccDramTileDistribution
<
Problem
>());
auto
dq_acc
=
decltype
(
load_tile
(
dq_acc_dram_window
)){};
auto
dq_acc
=
decltype
(
load_tile
(
dq_acc_dram_window
)){};
clear_tile
(
dq_acc
);
clear_tile
(
dq_acc
);
...
@@ -118,7 +118,7 @@ struct BlockFmhaBwdConvertQGrad
...
@@ -118,7 +118,7 @@ struct BlockFmhaBwdConvertQGrad
// declare dq
// declare dq
constexpr
auto
dq_converted_dstr
=
constexpr
auto
dq_converted_dstr
=
Policy
::
template
MakePostQGradAccD
eterministicD
ramTileDistribution
<
Problem
>();
Policy
::
template
MakePostQGradAccDramTileDistribution
<
Problem
>();
auto
dq_converted
=
make_static_distributed_tensor
<
QGradDataType
>
(
dq_converted_dstr
);
auto
dq_converted
=
make_static_distributed_tensor
<
QGradDataType
>
(
dq_converted_dstr
);
sweep_tile_span
(
dq_acc_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
dq_acc_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
...
@@ -130,8 +130,7 @@ struct BlockFmhaBwdConvertQGrad
...
@@ -130,8 +130,7 @@ struct BlockFmhaBwdConvertQGrad
});
});
});
});
constexpr
auto
dq_dstr
=
constexpr
auto
dq_dstr
=
Policy
::
template
MakePostQGradDramTileDistribution
<
Problem
>();
Policy
::
template
MakePostQGradDeterministicDramTileDistribution
<
Problem
>();
auto
dq
=
make_static_distributed_tensor
<
QGradDataType
>
(
dq_dstr
);
auto
dq
=
make_static_distributed_tensor
<
QGradDataType
>
(
dq_dstr
);
dq
.
get_thread_buffer
()
=
dq_converted
.
get_thread_buffer
();
dq
.
get_thread_buffer
()
=
dq_converted
.
get_thread_buffer
();
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
8c967d76
...
@@ -473,28 +473,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -473,28 +473,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dq_dram_block_window_tmp
.
get_window_lengths
(),
dq_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_q_start
,
0
});
{
seqlen_q_start
,
0
});
// Deterministic mode staff
auto
dq_buffer_view
=
dq_dram_block_window_tmp
.
get_bottom_tensor_view
().
get_buffer_view
();
auto
dq_tensor_desc
=
dq_dram_block_window_tmp
.
get_bottom_tensor_view
().
get_tensor_descriptor
();
auto
seqlen_q
=
dq_tensor_desc
.
get_lengths
()[
number
<
0
>
{}];
auto
hdim_q
=
dq_tensor_desc
.
get_lengths
()[
number
<
1
>
{}];
constexpr
auto
dq_fold
=
4
;
auto
dq_write_tensor_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
seqlen_q
/
dq_fold
,
hdim_q
*
dq_fold
),
make_tuple
(
hdim_q
*
dq_fold
,
1
),
number
<
kAlignmentQGrad
>
{},
number
<
1
>
{});
auto
dq_tensor_view
=
tensor_view
<
decltype
(
dq_buffer_view
),
decltype
(
dq_write_tensor_desc
)
>
{
dq_buffer_view
,
dq_write_tensor_desc
};
auto
dq_dram_window_deterministic
=
make_tile_window
(
dq_tensor_view
,
make_tuple
(
number
<
kM0
/
dq_fold
>
{},
number
<
kQKHeaddim
*
dq_fold
>
{}),
{
seqlen_q_start
/
dq_fold
,
0
});
using
SPTBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPTBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SPGradTBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
SPGradTBlockTileType
=
decltype
(
gemm_2
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
using
QGradBlockTileType
=
decltype
(
gemm_4
.
MakeCBlockTile
());
...
@@ -807,19 +785,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -807,19 +785,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}
}
if
constexpr
(
kIsDeterministic
)
if
constexpr
(
kIsDeterministic
)
{
{
auto
dq_write_reg_tensor
=
make_static_distributed_tensor
<
AccDataType
>
(
store_tile
(
dq_dram_window
,
dq_acc
);
Policy
::
template
MakeQGradWriteBlockDescriptor
<
Problem
>());
dq_write_reg_tensor
.
get_thread_buffer
()
=
dq_acc
.
get_thread_buffer
();
store_tile
(
dq_dram_window_deterministic
,
dq_write_reg_tensor
);
move_tile_window
(
dq_dram_window_deterministic
,
{
kM0
/
dq_fold
,
0
});
}
}
else
else
{
{
update_tile
(
dq_dram_window
,
dq_acc
);
update_tile
(
dq_dram_window
,
dq_acc
);
move_tile_window
(
dq_dram_window
,
{
kM0
,
0
});
}
}
move_tile_window
(
dq_dram_window
,
{
kM0
,
0
});
i_total_loops
+=
1
;
i_total_loops
+=
1
;
seqlen_q_step
+=
kM0
;
seqlen_q_step
+=
kM0
;
...
@@ -1047,12 +1019,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -1047,12 +1019,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
if
constexpr
(
kIsDeterministic
)
if
constexpr
(
kIsDeterministic
)
{
{
auto
dq_write_reg_tensor
=
make_static_distributed_tensor
<
AccDataType
>
(
store_tile
(
dq_dram_window
,
dq_acc
);
Policy
::
template
MakeQGradWriteBlockDescriptor
<
Problem
>());
dq_write_reg_tensor
.
get_thread_buffer
()
=
dq_acc
.
get_thread_buffer
();
store_tile
(
dq_dram_window_deterministic
,
dq_write_reg_tensor
);
}
}
else
else
{
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
8c967d76
...
@@ -167,7 +167,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -167,7 +167,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
Problem
::
kIsDeterministic
?
true
:
false
>
;
false
>
;
using
BlockGemmPolicy
=
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
GemmDataType
,
...
@@ -534,91 +534,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -534,91 +534,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePostQGradAccDeterministicDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePostQGradAccDramTileDistribution
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QGradDataType
,
typename
Problem
::
QGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
Shape
::
WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
Shape
::
WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
Shape
::
WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
WarpGemmAttrImpl
=
typename
WarpGemm
::
WarpGemmAttribute
::
Impl
;
constexpr
index_t
MWarp
=
Problem
::
Shape
::
BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
Shape
::
BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
Shape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
Shape
::
kQKHeaddim
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
auto
dq_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
1
>
,
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
2
,
3
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
0
>>
{};
constexpr
auto
dq_block_inner_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
1
>
,
sequence
<
WarpGemmAttrImpl
::
kCM0PerLane
,
WarpGemmAttrImpl
::
kCMLane
>
,
sequence
<
WarpGemmAttrImpl
::
kCNLane
,
WarpGemmAttrImpl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
2
,
3
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
1
>>
{};
constexpr
auto
dq_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
dq_block_outer_dstr_encoding
,
dq_block_inner_dstr_encoding
);
constexpr
auto
dq_block_dstr
=
make_static_tile_distribution
(
dq_block_dstr_encode
);
return
dq_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePostQGradDeterministicDramTileDistribution
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QGradDataType
,
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
typename
Problem
::
QGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
Shape
::
WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
Shape
::
WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
Shape
::
WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
constexpr
index_t
MWarp
=
Problem
::
Shape
::
BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
Shape
::
BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
Shape
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
Shape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
Shape
::
kQKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
Shape
::
kQKHeaddim
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
auto
dq_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
dq_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
index_t
K1
=
16
/
sizeof
(
AccDataType
);
dq_block_outer_dstr_encoding
,
typename
WarpGemm
::
CWarpDstrEncoding
{})
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
auto
dq_block_dstr
=
make_static_tile_distribution
(
dq_block_dstr_encode
);
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
kMPerBlock
/
(
M1
*
M2
);
return
dq_block_dstr
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
1
>
,
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
3
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
1
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePostQGrad
Acc
DramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakePostQGradDramTileDistribution
()
{
{
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
...
@@ -1079,7 +1020,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1079,7 +1020,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
Problem
::
kIsDeterministic
?
true
:
false
>
;
false
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
...
@@ -1554,7 +1495,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1554,7 +1495,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
::
at
(
number
<
2
>
{}),
Problem
::
kIsDeterministic
?
true
:
false
>
;
false
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
...
@@ -1581,54 +1522,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
...
@@ -1581,54 +1522,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return
ds_block_dstr
;
return
ds_block_dstr
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQGradWriteBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
::
at
(
number
<
2
>
{}),
true
>
;
using
WarpGemmAttrImpl
=
typename
WarpGemm
::
WarpGemmAttribute
::
Impl
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
auto
dq_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
dq_block_inner_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
WarpGemmAttrImpl
::
kCM0PerLane
,
WarpGemmAttrImpl
::
kCMLane
>
,
sequence
<
WarpGemmAttrImpl
::
kCNLane
,
WarpGemmAttrImpl
::
kCM1PerLane
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{};
constexpr
auto
dq_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
dq_block_outer_dstr_encoding
,
dq_block_inner_dstr_encoding
);
constexpr
auto
dq_block_dstr
=
make_static_tile_distribution
(
dq_block_dstr_encode
);
return
dq_block_dstr
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasTLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasTLdsBlockDescriptor
()
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment