Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3dc5db72
Commit
3dc5db72
authored
Oct 21, 2024
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
b924e330
e547c141
Changes
121
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
688 additions
and
301 deletions
+688
-301
include/ck_tile/host/reference/reference_im2col.hpp
include/ck_tile/host/reference/reference_im2col.hpp
+117
-45
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-0
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+171
-0
include/ck_tile/ops/fmha/block/block_masking.hpp
include/ck_tile/ops/fmha/block/block_masking.hpp
+2
-2
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+77
-14
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+70
-13
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+22
-29
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
...fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
+8
-9
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+21
-23
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+2
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+6
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+61
-56
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+55
-45
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...lock_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+15
-8
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+4
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+35
-31
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+6
-5
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+6
-9
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+7
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+2
-2
No files found.
include/ck_tile/host/reference/reference_im2col.hpp
View file @
3dc5db72
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -9,53 +9,125 @@
namespace
ck_tile
{
template
<
typename
T
>
CK_TILE_HOST
void
reference_im2col
(
HostTensor
<
T
>&
in_mtx_host_ref
,
const
HostTensor
<
T
>&
in_host
,
int
/*N*/
,
int
/*K*/
,
int
C
,
int
/*Y*/
,
int
X
,
int
Hi
,
int
Wi
,
int
Ho
,
int
Wo
,
int
ConvStrideH
,
int
ConvStrideW
,
int
ConvDilationH
,
int
ConvDilationW
,
int
InLeftPadH
,
int
InLeftPadW
,
int
/*InRightPadH*/
,
int
/*InRightPadW*/
)
template
<
typename
InDataType
,
typename
OutDataType
,
index_t
NDimSpatial
>
CK_TILE_HOST
void
reference_im2col
(
const
HostTensor
<
InDataType
>&
in_host
,
HostTensor
<
OutDataType
>&
out_host
,
const
ck_tile
::
conv
::
ConvParam
&
conv_params
)
{
int
GemmM
=
in_mtx_host_ref
.
get_lengths
()[
0
];
int
GemmK
=
in_mtx_host_ref
.
get_lengths
()[
1
];
const
long_index_t
G
=
in_host
.
get_lengths
()[
0
];
const
long_index_t
N
=
in_host
.
get_lengths
()[
1
];
const
long_index_t
C
=
in_host
.
get_lengths
()[
2
];
for
(
int
gemm_m
=
0
;
gemm_m
<
GemmM
;
++
gemm_m
)
if
constexpr
(
NDimSpatial
==
1
)
{
int
mtmp
=
gemm_m
;
int
n
=
mtmp
/
(
Ho
*
Wo
);
mtmp
-=
n
*
Ho
*
Wo
;
int
ho
=
mtmp
/
Wo
;
int
wo
=
mtmp
-
ho
*
Wo
;
for
(
int
gemm_k
=
0
;
gemm_k
<
GemmK
;
++
gemm_k
)
{
int
ktmp
=
gemm_k
;
int
y
=
ktmp
/
(
X
*
C
);
ktmp
-=
y
*
X
*
C
;
int
x
=
ktmp
/
C
;
int
c
=
ktmp
-
x
*
C
;
int
hi
=
y
*
ConvDilationH
+
ho
*
ConvStrideH
-
InLeftPadH
;
int
wi
=
x
*
ConvDilationW
+
wo
*
ConvStrideW
-
InLeftPadW
;
bool
inbound
=
(
hi
>=
0
&&
hi
<
Hi
&&
wi
>=
0
&&
wi
<
Wi
);
in_mtx_host_ref
(
gemm_m
,
gemm_k
)
=
inbound
?
in_host
(
n
,
hi
,
wi
,
c
)
:
0
;
}
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
0
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
wo
)
{
long_index_t
row
=
n
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
3
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
const
long_index_t
Ho
=
conv_params
.
output_spatial_lengths_
[
0
];
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
1
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
ho
,
auto
wo
)
{
long_index_t
row
=
n
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
y
=
0
;
y
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
y
)
{
auto
hi
=
static_cast
<
long_index_t
>
(
ho
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
y
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
1
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
1
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
1
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
1
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
hi
>=
0
&&
type_convert
<
std
::
size_t
>
(
hi
)
<
in_host
.
get_lengths
()[
3
]
&&
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
4
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
hi
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
const
long_index_t
Do
=
conv_params
.
output_spatial_lengths_
[
0
];
const
long_index_t
Ho
=
conv_params
.
output_spatial_lengths_
[
1
];
const
long_index_t
Wo
=
conv_params
.
output_spatial_lengths_
[
2
];
auto
func
=
[
&
](
auto
g
,
auto
n
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
long_index_t
row
=
n
*
Do
*
Ho
*
Wo
+
d_o
*
Ho
*
Wo
+
ho
*
Wo
+
wo
;
long_index_t
column
=
0
;
for
(
long_index_t
z
=
0
;
z
<
conv_params
.
filter_spatial_lengths_
[
0
];
++
z
)
{
auto
di
=
static_cast
<
long_index_t
>
(
d_o
*
conv_params
.
conv_filter_strides_
[
0
])
+
static_cast
<
long_index_t
>
(
z
*
conv_params
.
conv_filter_dilations_
[
0
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
0
]);
for
(
long_index_t
y
=
0
;
y
<
conv_params
.
filter_spatial_lengths_
[
1
];
++
y
)
{
auto
hi
=
static_cast
<
long_index_t
>
(
ho
*
conv_params
.
conv_filter_strides_
[
1
])
+
static_cast
<
long_index_t
>
(
y
*
conv_params
.
conv_filter_dilations_
[
1
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
1
]);
for
(
long_index_t
x
=
0
;
x
<
conv_params
.
filter_spatial_lengths_
[
2
];
++
x
)
{
auto
wi
=
static_cast
<
long_index_t
>
(
wo
*
conv_params
.
conv_filter_strides_
[
2
])
+
static_cast
<
long_index_t
>
(
x
*
conv_params
.
conv_filter_dilations_
[
2
])
-
static_cast
<
long_index_t
>
(
conv_params
.
input_left_pads_
[
2
]);
for
(
long_index_t
c
=
0
;
c
<
C
;
++
c
)
{
if
(
di
>=
0
&&
type_convert
<
std
::
size_t
>
(
di
)
<
in_host
.
get_lengths
()[
3
]
&&
hi
>=
0
&&
type_convert
<
std
::
size_t
>
(
hi
)
<
in_host
.
get_lengths
()[
4
]
&&
wi
>=
0
&&
type_convert
<
std
::
size_t
>
(
wi
)
<
in_host
.
get_lengths
()[
5
])
{
InDataType
v_in
=
in_host
(
g
,
n
,
c
,
di
,
hi
,
wi
);
out_host
(
g
,
row
,
column
)
=
type_convert
<
OutDataType
>
(
v_in
);
}
column
++
;
}
}
}
}
};
make_ParallelTensorFunctor
(
func
,
G
,
N
,
Do
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
}
}
// namespace ck_tile
include/ck_tile/ops/epilogue.hpp
View file @
3dc5db72
...
...
@@ -3,5 +3,6 @@
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
0 → 100644
View file @
3dc5db72
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace
ck_tile
{
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kTilePermute_
,
index_t
kRank_
,
index_t
kPerm0
,
index_t
kPerm1
,
index_t
TileSize0
,
index_t
TileSize1
,
index_t
kPerm2
=
0
,
index_t
kPerm3
=
0
,
index_t
kPerm4
=
0
,
index_t
TileSize2
=
0
,
index_t
TileSize3
=
0
,
index_t
TileSize4
=
0
>
struct
CShuffleEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTilePermute
=
kTilePermute_
;
static
constexpr
index_t
kRank
=
kRank_
;
static
constexpr
index_t
kPerm
[
CK_TILE_MAX_RANK
]
=
{
kPerm0
,
kPerm1
,
kPerm2
,
kPerm3
,
kPerm4
};
static
constexpr
index_t
tile_sizes
[
CK_TILE_MAX_RANK
]
=
{
TileSize0
,
TileSize1
,
TileSize2
,
TileSize3
,
TileSize4
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
CShuffleEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
const
index_t
*
kPerm
=
Problem
::
kPerm
;
static
constexpr
bool
kTilePermute
=
Problem
::
kTilePermute
;
static
constexpr
index_t
kRank
=
Problem
::
kRank
;
const
index_t
*
tile_sizes
=
Problem
::
tile_sizes
;
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
template
<
typename
OAccTile
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
{
using
DataType
=
typename
OAccTile
::
DataType
;
// Get thread buffer
auto
&
thread_buf
=
o_acc_tile
.
get_thread_buffer
();
// Create a temporary buffer to hold the permuted data
thread_buffer
<
DataType
,
OAccTile
::
kThreadElementSpaceSize
>
permuted_thread_buf
;
// Get the lengths of each dimension
auto
thread_tensor_lengths
=
o_acc_tile
.
get_lengths
();
// Total number of elements
index_t
total_elements
=
OAccTile
::
kThreadElementSpaceSize
;
// Iterate over all elements
for
(
index_t
linear_idx
=
0
;
linear_idx
<
total_elements
;
++
linear_idx
)
{
// Convert linear index to multi-dimensional indices
array
<
index_t
,
kRank
>
indices
;
index_t
remaining
=
linear_idx
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
indices
(
rev_i
)
=
remaining
%
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
remaining
/=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Apply the permutation
array
<
index_t
,
kRank
>
permuted_indices
;
static_for
<
0
,
kRank
,
1
>
{}(
[
&
](
auto
i
)
{
permuted_indices
(
i
)
=
indices
.
get
(
number
<
Problem
::
kPerm
[
i
]
>
{});
});
// Compute offsets
index_t
dst_offset
=
0
;
index_t
stride
=
1
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
dst_offset
+=
permuted_indices
[
rev_i
]
*
stride
;
stride
*=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Move the data
permuted_thread_buf
(
dst_offset
)
=
thread_buf
[
linear_idx
];
}
// Copy the permuted data back to the original thread buffer
for
(
index_t
i
=
0
;
i
<
total_elements
;
++
i
)
{
thread_buf
.
set_as
(
i
,
permuted_thread_buf
.
get
(
i
));
}
}
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t
tile_coords
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
tile_coords
[
i
]
=
current_window_origin
[
i
]
/
tile_sizes
[
i
];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t
permuted_tile_coords
[
CK_TILE_MAX_RANK
];
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_tile_coords
[
i
]
=
tile_coords
[
kPerm
[
i
]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t
permuted_window_origin
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_window_origin
[
i
]
=
permuted_tile_coords
[
i
]
*
tile_sizes
[
i
];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename
ODramWindowTmp
::
BottomTensorIndex
step
=
{};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
step
[
i
]
=
permuted_window_origin
[
i
]
-
current_window_origin
[
i
];
}
// Move the window
move_tile_window
(
o_dram_window_tmp
,
step
);
// Permute the data within the tile if necessary
if
constexpr
(
kTilePermute
)
{
permute_tile_data
(
o_acc_tile
);
}
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/block/block_masking.hpp
View file @
3dc5db72
...
...
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
auto
[
origin_start
,
origin_end
]
=
GetTileRangeAlongX
(
i_y
,
height
,
width
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
x_total
/
num_splits
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
integer_divide_ceil
(
x_total
,
num_splits
)
)
;
const
index_t
split_start
=
x_per_split
*
i_split
;
const
index_t
split_end
=
(
i_split
==
num_splits
-
1
?
x_total
:
split_start
+
x_per_split
)
;
const
index_t
split_end
=
split_start
+
x_per_split
;
return
ck_tile
::
make_tuple
(
ck_tile
::
max
(
origin_start
,
split_start
),
ck_tile
::
min
(
origin_end
,
split_end
));
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
3dc5db72
...
...
@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...
...
@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
GenericAttentionMaskEnum
mask_type
;
};
struct
FmhaBwd
Common
Dropout
Kargs
struct
FmhaBwdDropout
SeedOffset
{
void
init_dropout
(
const
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
,
const
float
raw_scale
)
template
<
typename
T
>
union
ValueOrPointer
{
T
val
;
const
T
*
ptr
;
};
ValueOrPointer
<
uint64_t
>
drop_seed
;
ValueOrPointer
<
uint64_t
>
drop_offset
;
bool
is_drop_seed_offset_from_host
;
};
struct
FmhaBwdCommonDropoutKargs
:
FmhaBwdDropoutSeedOffset
{
void
init_dropout
(
float
p_drop
,
uint64_t
seed
,
uint64_t
offset
,
float
raw_scale
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
scale_rp_undrop
=
rp_undrop
*
raw_scale
;
this
->
drop_seed
.
val
=
seed
;
this
->
drop_offset
.
val
=
offset
;
this
->
is_drop_seed_offset_from_host
=
true
;
}
void
init_dropout
(
float
p_drop
,
const
uint64_t
*
seed_ptr
,
const
uint64_t
*
offset_ptr
,
float
raw_scale
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
...
...
@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
rp_undrop
=
1.0
/
p_undrop
;
scale_rp_undrop
=
rp_undrop
*
raw_scale
;
drop_seed
=
std
::
get
<
0
>
(
drop_seed_offset
);
drop_offset
=
std
::
get
<
1
>
(
drop_seed_offset
);
this
->
drop_seed
.
ptr
=
seed_ptr
;
this
->
drop_offset
.
ptr
=
offset_ptr
;
this
->
is_drop_seed_offset_from_host
=
false
;
}
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
};
struct
FmhaBwdBatchModeDropoutKargs
:
FmhaBwdCommonDropoutKargs
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
};
struct
FmhaBwdDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
...
...
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
,
scale
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
),
scale
);
}
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
...
...
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
,
scale
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
),
scale
);
}
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
...
...
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_seed
.
val
:
*
kargs
.
drop_seed
.
ptr
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_offset
.
val
:
*
kargs
.
drop_offset
.
ptr
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
}
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
3dc5db72
...
...
@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...
...
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile
::
index_t
batch_stride_lse
=
0
;
};
struct
FmhaFwd
Common
Dropout
Kargs
struct
FmhaFwdDropout
SeedOffset
{
void
init_dropout
(
const
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
template
<
typename
T
>
union
ValueOrPointer
{
T
val
;
const
T
*
ptr
;
};
ValueOrPointer
<
uint64_t
>
drop_seed
;
ValueOrPointer
<
uint64_t
>
drop_offset
;
bool
is_drop_seed_offset_from_host
;
};
struct
FmhaFwdCommonDropoutKargs
:
FmhaFwdDropoutSeedOffset
{
void
init_dropout
(
float
p_drop
,
uint64_t
seed
,
uint64_t
offset
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
this
->
drop_seed
.
val
=
seed
;
this
->
drop_offset
.
val
=
offset
;
this
->
is_drop_seed_offset_from_host
=
true
;
}
void
init_dropout
(
float
p_drop
,
const
uint64_t
*
seed_ptr
,
const
uint64_t
*
offset_ptr
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
drop_seed
=
std
::
get
<
0
>
(
drop_seed_offset
);
drop_offset
=
std
::
get
<
1
>
(
drop_seed_offset
);
this
->
drop_seed
.
ptr
=
seed_ptr
;
this
->
drop_offset
.
ptr
=
offset_ptr
;
this
->
is_drop_seed_offset_from_host
=
false
;
}
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
};
struct
FmhaFwdBatchModeDropoutKargs
:
FmhaFwdCommonDropoutKargs
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
...
...
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
));
}
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
...
...
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
));
}
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
...
...
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return
BlockDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_seed
.
val
:
*
kargs
.
drop_seed
.
ptr
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_offset
.
val
:
*
kargs
.
drop_offset
.
ptr
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
,
kargs
.
is_store_randval
};
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
3dc5db72
...
...
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
o_ptr
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
num_splits
;
...
...
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
...
...
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std
::
conditional_t
<
kStoreLSE
,
CommonLSEKargs
,
EmptyKargs
<
0
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
{
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
batch_stride_o
;
};
struct
GroupModeKargs
...
...
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
...
...
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_ptr
,
batch
,
max_seqlen_q
,
seqlen_q
,
hdim_v
,
num_splits
,
...
...
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for fp8_static_quant args
batch_stride_o
,
batch_stride_lse_acc
};
batch_stride_lse_acc
,
batch_stride_o_acc
,
batch_stride_o
};
if
constexpr
(
kStoreLSE
)
{
...
...
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
const
void
*
seqstart_q_ptr
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
...
...
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
)
{
...
...
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_ptr
,
batch
,
max_seqlen_q
,
-
1
,
// seqlen will be updated by another pointer
hdim_v
,
num_splits
,
...
...
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
...
...
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
hdim_v
_
)
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
return
TilePartitioner
::
GridSize
(
batch_size
_
,
nhead
_
,
seqlen_q
_
,
hdim_v
_
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_o_acc
=
0
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_o
=
0
;
...
...
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
row_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
query_start
;
}
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
...
...
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
}
// for simplicity, batch stride we just modify the pointer
...
...
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto
o_acc_dram
=
[
&
]()
{
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
make_tuple
(
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
num_splits
,
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
split_stride_o_acc
,
kargs
.
row_stride_o_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
...
...
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
const
index_t
padded_
max_
seqlen_q
=
const
index_t
padded_seqlen_q
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
const
index_t
padded_hdim_v
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
return
transform_tensor_view
(
o_acc_dram_view
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_
max_
seqlen_q
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_seqlen_q
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
...
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity
{},
// lse_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
}
else
...
...
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window
,
lse_dram_window
,
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
}
}();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
View file @
3dc5db72
...
...
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
hdim_v
_
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
_
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
_
,
kN1
),
nhead
_
,
batch_size
_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
3dc5db72
...
...
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
...
...
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
};
struct
GroupModeKargs
...
...
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const
int32_t
*
seqstart_k_ptr
;
const
int32_t
*
seqlen_k_ptr
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_k
;
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
;
// only used for paged-kvcache
};
using
Kargs
=
std
::
conditional_t
<
kIsGroupMode
,
GroupModeKargs
,
BatchModeKargs
>
;
...
...
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
...
...
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
),
batch_stride_q
,
batch_stride_k
,
batch_stride_v
};
batch_stride_v
,
batch_stride_lse_acc
,
batch_stride_o_acc
};
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile
::
index_t
nhead_stride_bias
,
ck_tile
::
index_t
nhead_stride_lse_acc
,
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
batch_stride_k
,
ck_tile
::
index_t
batch_stride_v
,
ck_tile
::
index_t
batch_stride_lse_acc
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
batch_stride_k
,
// only used for paged-kvcache
ck_tile
::
index_t
batch_stride_v
,
// only used for paged-kvcache
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
,
ck_tile
::
index_t
window_size_left
,
...
...
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for bias
...
...
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
seqlen_q
,
hdim_v
,
num_splits
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
,
num_splits
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t
batch_offset_v
=
0
;
long_index_t
batch_offset_bias
=
0
;
long_index_t
batch_offset_lse_acc
=
0
;
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_o_acc
=
0
;
if
constexpr
(
kIsGroupMode
)
{
...
...
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_lse_acc
=
query_start
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
...
...
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
+
key_start
;
}
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
stride_o_acc
;
// get real # queries & # keys under group mode
kargs
.
seqlen_q
=
kargs
.
seqstart_q_ptr
[
i_batch
+
1
]
-
kargs
.
seqstart_q_ptr
[
i_batch
];
...
...
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_cache_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
...
...
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
hdim_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentO
>
{},
make_tuple
(
kargs
.
stride_o_acc
,
1
),
number
<
1
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
View file @
3dc5db72
...
...
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
*
num_splits
,
batch_size
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
3dc5db72
...
...
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
},
s_acc
,
bias_s_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
...
...
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 4, OGrad@V Gemm2
auto
dp_acc
=
SPGradBlockTileType
{};
...
...
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 5, P^T(PGrad^T - D)
auto
ds
=
SPGradBlockTileType
{};
...
...
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbias_tile
,
shuffled_dbias_tile
);
store_tile
(
dbias_dram_window
,
dbias_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
...
...
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 7, SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
...
...
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
__builtin_amdgcn_sched_barrier
(
0
);
// Results Scale
if
constexpr
(
FmhaDropout
::
IsDropout
)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
3dc5db72
...
...
@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/
pipeline
/block_gemm_
pipeline_
problem.hpp"
#include "ck_tile/ops/gemm/
block
/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
...
...
@@ -25,15 +25,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
...
...
@@ -52,21 +53,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -84,21 +86,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
...
...
@@ -117,21 +120,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -149,21 +153,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -181,7 +186,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
// these are for global load
...
...
@@ -1727,7 +1732,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
0
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
0
>
()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
...
...
@@ -1759,7 +1764,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
1
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
1
>
()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
...
...
@@ -1777,7 +1782,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
2
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
2
>
()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
...
...
@@ -1796,7 +1801,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
3
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
3
>
()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
...
...
@@ -1830,7 +1835,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
4
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
4
>
()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
3dc5db72
...
...
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
LSEElementFunction
&
lse_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
max_
seqlen_q
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
{
// lse_acc tile in LDS
...
...
@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_accum
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
LSEDataType
>::
infinity
());
block_tile_reduce_sync
(
lse_max
,
f_max
,
bool_constant
<
false
>
{});
static
const
auto
get_validated_m
=
[](
LSEDataType
raw_m
)
{
return
raw_m
==
-
numeric
<
LSEDataType
>::
infinity
()
?
type_convert
<
LSEDataType
>
(
0.
f
)
:
raw_m
;
};
decltype
(
lse_accum
)
lse_exp
;
{
constexpr
auto
spans
=
decltype
(
lse_exp
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
if
(
lse_max
[
i_idx
]
==
-
numeric
<
LSEDataType
>::
infinity
())
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
lse_exp
(
i_j_idx
)
=
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
get_validated_m
(
lse_max
(
i_idx
)));
});
lse_exp
(
i_j_idx
)
=
ck_tile
::
type_convert
<
LSEDataType
>
(
0.0
f
);
});
}
else
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
lse_exp
(
i_j_idx
)
=
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
lse_max
(
i_idx
));
});
}
});
}
...
...
@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
if
(
lse_sum
(
i_idx
)
==
0.
f
||
lse_sum
(
i_idx
)
!=
lse_sum
(
i_idx
))
{
lse_logsum
(
i_idx
)
=
numeric
<
LSEDataType
>::
infinity
();
}
if
(
lse_sum
[
i_idx
]
==
ck_tile
::
type_convert
<
LSEDataType
>
(
0.0
f
))
lse_logsum
(
i_idx
)
=
-
numeric
<
LSEDataType
>::
infinity
();
else
{
lse_logsum
(
i_idx
)
=
ck_tile
::
log
(
lse_sum
(
i_idx
))
+
get_validated_m
(
lse_max
(
i_idx
));
}
lse_logsum
(
i_idx
)
=
ck_tile
::
log
(
lse_sum
(
i_idx
))
+
lse_max
(
i_idx
);
});
}
...
...
@@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline
constexpr
auto
spans
=
decltype
(
lse_accum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
if
(
lse_logsum
(
i_idx
)
==
-
numeric
<
LSEDataType
>::
infinity
())
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
lse_accum
.
get_tile_distribution
(),
i_j_idx
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
lse_accum
.
get_tile_distribution
(),
i_j_idx
);
const
auto
col
=
x_indices
.
at
(
number
<
1
>
{});
if
(
col
<
num_splits
)
{
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
const
auto
col
=
x_indices
.
at
(
number
<
1
>
{});
if
(
col
<
num_splits
)
{
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
lse_acc_lds
(
row
,
col
)
=
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
lse_logsum
(
i_idx
));
}
});
lse_acc_lds
(
row
,
col
)
=
ck_tile
::
type_convert
<
LSEDataType
>
(
0.0
f
);
}
});
}
else
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
lse_accum
.
get_tile_distribution
(),
i_j_idx
);
const
auto
col
=
x_indices
.
at
(
number
<
1
>
{});
if
(
col
<
num_splits
)
{
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
lse_acc_lds
(
row
,
col
)
=
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
lse_logsum
(
i_idx
));
}
});
}
});
}
block_sync_lds
();
if
constexpr
(
kStoreLSE
)
{
constexpr
auto
spans
=
decltype
(
lse_logsum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
if
(
lse_logsum
(
i_idx
)
==
numeric
<
LSEDataType
>::
infinity
())
{
lse_logsum
(
i_idx
)
=
-
numeric
<
LSEDataType
>::
infinity
();
}
});
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse_logsum
));
}
...
...
@@ -261,7 +271,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
const
index_t
padded_
max_
seqlen_q
=
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
kM0
;
const
index_t
padded_seqlen_q
=
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
kM0
;
for
(
index_t
i_split
=
0
;
i_split
<
num_splits
;
++
i_split
)
{
...
...
@@ -282,7 +292,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
move_tile_window
(
o_acc_dram_window
,
{
padded_
max_
seqlen_q
,
0
});
move_tile_window
(
o_acc_dram_window
,
{
padded_seqlen_q
,
0
});
}
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
...
...
@@ -297,7 +307,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
index_t
num_splits
,
index_t
max_
seqlen_q
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
{
return
operator
()(
lse_acc_dram_block_window
,
...
...
@@ -306,7 +316,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity
{},
identity
{},
num_splits
,
max_
seqlen_q
,
seqlen_q
,
smem_ptr
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
View file @
3dc5db72
...
...
@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentOacc
()
{
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
return
16
/
sizeof
(
OaccDataType
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
kNPerBlock
/
N0
;
return
min
(
N1
,
static_cast
<
index_t
>
(
16
/
sizeof
(
OaccDataType
)));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentO
()
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
return
16
/
sizeof
(
ODataType
);
return
GetAlignmentOacc
<
Problem
>
();
}
template
<
typename
Problem
>
...
...
@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOaccDramTileDistribution
()
{
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
N1
=
16
/
sizeof
(
OaccDataType
);
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
M2
=
get_warp_size
()
/
N0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
kNPerBlock
/
N0
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
3dc5db72
...
...
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
...
...
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
// check early exit if
masked and
no work to do
.
if
constexpr
(
FmhaMask
::
IsMasking
||
kHasUnevenSplits
)
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
const
index_t
original_num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
...
...
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
3dc5db72
...
...
@@ -5,7 +5,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
...
...
@@ -75,15 +76,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
...
@@ -116,7 +118,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
@@ -199,15 +201,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
...
@@ -240,7 +243,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBSmemCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmASmemBSmemCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
@@ -954,15 +957,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
@@ -996,7 +1000,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
include/ck_tile/ops/gemm.hpp
View file @
3dc5db72
...
...
@@ -23,12 +23,13 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
3dc5db72
...
...
@@ -11,20 +11,12 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GemmKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
kBlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
...
...
@@ -32,6 +24,10 @@ struct GemmKernel
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutC
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
{
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
...
...
@@ -184,6 +180,7 @@ struct GemmKernel
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CBlockWindow_pad
,
acc
);
}
};
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
3dc5db72
...
...
@@ -4,15 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV1
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
GemmPipelineAGmemBGmemCRegV1
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
Problem
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
Problem
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
Problem
::
LayoutC
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
return
ck_tile
::
integer_divide_ceil
(
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
3dc5db72
...
...
@@ -7,9 +7,9 @@
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV1
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
#if 0
// 2d
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment