gaoqiong / composable_kernel_ROCM / Commits

Commit 036c5234, authored May 14, 2024 by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

Parents: 22995e9a, 7843a8a7
Changes: 207
Showing 20 changed files with 1156 additions and 122 deletions (+1156 / -122)
include/ck_tile/ops/fmha.hpp  (+2 / -0)
include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp  (+37 / -0)
include/ck_tile/ops/fmha/block/block_position_encoding.hpp  (+189 / -0)
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp  (+72 / -13)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp  (+19 / -0)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp  (+1 / -1)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp  (+40 / -11)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp  (+44 / -14)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp  (+13 / -9)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp  (+40 / -11)
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp  (+3 / -2)
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  (+24 / -19)
library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp  (+303 / -21)
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp  (+159 / -0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp  (+52 / -0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp  (+1 / -0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp  (+4 / -0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc  (+24 / -0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp  (+21 / -21)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp  (+108 / -0)
include/ck_tile/ops/fmha.hpp  (+2 / -0)

...
@@ -3,7 +3,9 @@
 #pragma once
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
+#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
...
include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp  (new file, mode 100644; +37 / -0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

namespace ck_tile {

// This enum is used for codegen pattern matching
enum class BlockAttentionBiasEnum
{
    NO_BIAS          = 0,
    ELEMENTWISE_BIAS = 1, // attention bias: each element is added to the result of Q*K (after scale)
    ALIBI            = 2, // bias computed from position encoding, applied after scale
};

template <BlockAttentionBiasEnum>
struct BlockAttentionBiasEnumToStr;

template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>
{
    static constexpr const char* name = "";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
{
    static constexpr const char* name = "bias";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
{
    static constexpr const char* name = "alibi";
};

} // namespace ck_tile
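For reference, a standalone sketch (not part of this commit) of how such enum-to-string specializations can be used to build a kernel-name suffix at compile time; the bias_suffix() helper below is hypothetical, mirroring how the FMHA kernel name is assembled further down in this merge.

    #include <iostream>
    #include <string>

    enum class BlockAttentionBiasEnum { NO_BIAS = 0, ELEMENTWISE_BIAS = 1, ALIBI = 2 };

    template <BlockAttentionBiasEnum> struct BlockAttentionBiasEnumToStr;
    template <> struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>          { static constexpr const char* name = ""; };
    template <> struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS> { static constexpr const char* name = "bias"; };
    template <> struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>            { static constexpr const char* name = "alibi"; };

    // Hypothetical helper: empty suffix for NO_BIAS, "_<name>" otherwise.
    template <BlockAttentionBiasEnum E>
    std::string bias_suffix()
    {
        return E == BlockAttentionBiasEnum::NO_BIAS
                   ? std::string("")
                   : std::string("_") + BlockAttentionBiasEnumToStr<E>::name;
    }

    int main()
    {
        std::cout << "fmha_fwd" << bias_suffix<BlockAttentionBiasEnum::ALIBI>() << "\n"; // prints "fmha_fwd_alibi"
    }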
include/ck_tile/ops/fmha/block/block_position_encoding.hpp  (new file, mode 100644; +189 / -0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include <cmath>
#include <vector>

namespace ck_tile {

enum struct PositionEncodingEnum
{
    NO    = 0,
    ALIBI = 1,
};

/*
VERTICAL:
    [0] 1 2 3 4 5
    [0] 1 2 3 4 5
    [0] 1 2 3 4 5
    [0] 1 2 3 4 5

TOP_LEFT:
    [0] 1 2 3 4 5
     1 [0] 1 2 3 4
     2  1 [0] 1 2 3
     3  2  1 [0] 1 2

FROM_BOTTOM_RIGHT:
     2  1 [0] 1 2 3
     3  2  1 [0] 1 2
     4  3  2  1 [0] 1
     5  4  3  2  1 [0]
*/
enum struct AlibiMode
{
    VERTICAL          = 0,
    FROM_TOP_LEFT     = 1, // keep in sync with mask enum
    FROM_BOTTOM_RIGHT = 2,
};

template <typename DataType, bool RowMajor = true>
struct Alibi
{
    // RowMajor here means whether the pixels within the same thread run along the row or the column.
    // This may impact the performance of update(), while the result is the same.
    // e.g. fwd prefers RowMajor=true, some bwd cases prefer RowMajor=false
    CK_TILE_HOST_DEVICE
    Alibi(DataType slope_, index_t y_total_, index_t x_total_, AlibiMode mode_ = AlibiMode::VERTICAL)
    {
        slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;

        shift_left_up = [&]() {
            if(RowMajor)
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
            }
            else
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
            }
        }();

        shift_right_down = [&]() {
            if(RowMajor)
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
            }
            else
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
            }
        }();

        mode = mode_;
    }

    CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
    {
        if constexpr(RowMajor)
        {
            // at least 3 instructions per row
            index_t current_zero_point =
                mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
            // for every thread, most of the pixels run along the row, so the operation below
            // should be the main hot spot.
            auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
                                                       bit_cast<uint32_t>(col_idx + shift_left_up),
                                                       0));
            pixel += slope * position;
        }
        else
        {
            // at least 3 instructions per col
            index_t current_zero_point = mode == AlibiMode::VERTICAL
                                             ? row_idx + col_idx + shift_right_down
                                             : col_idx + shift_right_down;
            // for every thread, most of the pixels run along the col, so the operation below
            // should be the main hot spot.
            auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
                                                       bit_cast<uint32_t>(row_idx + shift_left_up),
                                                       0));
            pixel += slope * position;
        }
    }

    DataType slope;           // float?
    index_t shift_left_up;    // always positive
    index_t shift_right_down; // always positive
    AlibiMode mode;
};

template <typename DataType>
struct EmptyPositionEncoding
{
    CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/) {}
};

// can convert from the FA style left/right window to our generic coordinate
// if left_size < 0 && right_size == 0, it is a normal causal mask
// local is left_size >= 0 or right_size >= 0
template <typename DataType, bool RowMajor = true>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
                                                 index_t window_left_size,
                                                 index_t window_right_size,
                                                 index_t y_total,
                                                 index_t x_total,
                                                 GenericAttentionMaskEnum mask_enum)
{
    // assume mask_enum will never be NO_MASK, since if we do not have a mask, it's
    // totally OK to use constexpr
    bool is_causal = window_left_size < 0 && window_right_size == 0;
    AlibiMode alibi_mode =
        is_causal ? AlibiMode::VERTICAL
                  : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
    return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
}

// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
// Do we need a device version?
template <typename DataType>
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
{
    auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
        float start = std::powf(
            static_cast<float>(2),
            -std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
        std::vector<DataType> rtn;
        for(auto i = 0; i < n; i++)
        {
            rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
        }
        return rtn;
    };

    if(is_power_of_two_integer(nheads))
    {
        // power of 2 calculation
        return get_slopes_power_of_2(nheads);
    }
    else
    {
        ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
        auto v0 = get_slopes_power_of_2(closest_power_of_2);
        auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
        auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
            std::vector<DataType> sliced;
            for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
            {
                if(i % 2 == 0)
                    sliced.push_back(vec[i]);
            }
            std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
            return sliced_2;
        }(v1, nheads - closest_power_of_2);
        v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
        return v0;
    }
}

} // namespace ck_tile
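To make the slope rule in get_alibi_slopes() concrete, here is a standalone host-side sketch that uses only plain std:: math (no ck_tile types): for a power-of-two head count n, the slopes are start^(i+1) with start = 2^(-2^(-(log2(n) - 3))); for non-power-of-two counts, the remaining slopes are taken from the 2*closest_power_of_2 sequence at every other index. The function name alibi_slopes_ref is mine, for illustration only.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    std::vector<float> alibi_slopes_ref(int nheads)
    {
        auto pow2_slopes = [](int n) {
            std::vector<float> out;
            // start = 2^(-2^(-(log2(n) - 3)))
            float start = std::pow(2.0f, -std::pow(2.0f, -(std::log2(static_cast<float>(n)) - 3.0f)));
            for(int i = 0; i < n; ++i)
                out.push_back(start * std::pow(start, static_cast<float>(i))); // start^(i+1)
            return out;
        };

        if((nheads & (nheads - 1)) == 0)
            return pow2_slopes(nheads);

        // non-power-of-two: closest power of two, plus every-other slope of the doubled sequence
        int closest = 1 << static_cast<int>(std::floor(std::log2(static_cast<float>(nheads))));
        std::vector<float> slopes = pow2_slopes(closest);
        std::vector<float> extra  = pow2_slopes(closest * 2);
        for(int i = 0, taken = 0; i < static_cast<int>(extra.size()) && taken < nheads - closest;
            i += 2, ++taken)
            slopes.push_back(extra[i]);
        return slopes;
    }

    int main()
    {
        // 8 heads: 1/2, 1/4, ..., 1/256
        for(float s : alibi_slopes_ref(8))
            std::printf("%g ", s);
        std::printf("\n");
    }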
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp  (+72 / -13)

...
@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include <string>
 #include <type_traits>
...
@@ -33,6 +34,7 @@ struct FmhaFwdKernel
    using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
    using LSEDataType  = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using ODataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
    using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
    using VLayout      = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
...
@@ -41,7 +43,7 @@ struct FmhaFwdKernel
    static constexpr bool kPadSeqLenK  = FmhaPipeline::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
-   static constexpr bool kHasBias     = FmhaPipeline::kHasBias;
+   static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
    static constexpr bool kStoreLSE    = FmhaPipeline::kStoreLSE;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
    using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
...
@@ -81,7 +83,8 @@ struct FmhaFwdKernel
            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) +
            "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
            _SS_(FmhaPipeline::name) + "_" + "v" +
            (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
            (pn.empty() ? "" : "_" + pn) +
-           (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
-           (kStoreLSE ? "_lse" : "") + (kDoFp8StaticQuant ? "_squant" : "");
+           (BiasEnum == BlockAttentionBiasEnum::NO_BIAS
+                ? _SS_("")
+                : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
+           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") +
+           (kDoFp8StaticQuant ? "_squant" : "");
    #undef _SS_
    #undef _TS_
        // clang-format on
...
@@ -136,6 +139,13 @@ struct FmhaFwdKernel
        ck_tile::index_t batch_stride_bias = 0;
    };
+   struct FmhaFwdAlibiKargs
+   {
+       // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
+       const void* alibi_slope_ptr;
+       ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batches sharing the same slope
+   };
    struct FmhaFwdMaskKargs
    {
        // ck_tile::index_t window_size_left, window_size_right;
...
@@ -162,7 +172,11 @@ struct FmhaFwdKernel
    struct FmhaFwdBatchModeKargs
        : FmhaFwdCommonKargs,
-         std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
+         std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
+                            FmhaFwdBatchModeBiasKargs,
+                            std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
+                                               FmhaFwdAlibiKargs,
+                                               FmhaFwdEmptyKargs<0>>>,
          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
          std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
          std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
...
@@ -175,7 +189,11 @@ struct FmhaFwdKernel
    struct FmhaFwdGroupModeKargs
        : FmhaFwdCommonKargs,
-         std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
+         std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
+                            FmhaFwdCommonBiasKargs,
+                            std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
+                                               FmhaFwdAlibiKargs,
+                                               FmhaFwdEmptyKargs<0>>>,
          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
          std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
...
@@ -255,13 +273,18 @@ struct FmhaFwdKernel
                          batch_stride_v,
                          batch_stride_o};

-       if constexpr(kHasBias)
+       if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
            kargs.batch_stride_bias = batch_stride_bias;
        }
+       else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+       {
+           kargs.alibi_slope_ptr    = bias_ptr;
+           kargs.alibi_slope_stride = stride_bias;
+       }
        if constexpr(kHasMask)
        {
            kargs.window_size_left = window_size_left;
...
@@ -345,12 +368,17 @@ struct FmhaFwdKernel
                          reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                          reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

-       if constexpr(kHasBias)
+       if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
        }
+       else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+       {
+           kargs.alibi_slope_ptr    = bias_ptr;
+           kargs.alibi_slope_stride = stride_bias;
+       }
        if constexpr(kHasMask)
        {
            kargs.window_size_left = window_size_left;
...
@@ -421,14 +449,10 @@ struct FmhaFwdKernel
            {
                batch_offset_v = key_start;
            }
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = query_start * kargs.stride_bias + key_start;
            }
            else
            {
                batch_offset_bias = key_start;
            }
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = query_start;
...
@@ -461,7 +485,7 @@ struct FmhaFwdKernel
            batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
            batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
            batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
            }
...
@@ -585,7 +609,7 @@ struct FmhaFwdKernel
        const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto bias_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                const BiasDataType* bias_ptr =
                    reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
...
@@ -654,6 +678,39 @@ struct FmhaFwdKernel
            return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();

+       // WA i_batch capture structure binding before c++20
+       auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+           {
+               // data loading, shared by entire wg
+               // TODO: how to use s_read?
+               SaccDataType slope = *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
+                                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
+#if CK_TILE_FMHA_FWD_FAST_EXP2
+               slope *= ck_tile::log2e_v<>;
+#endif
+               if constexpr(kHasMask)
+               {
+                   return make_alibi_from_lr_mask<SaccDataType, true>(slope,
+                                                                      kargs.window_size_left,
+                                                                      kargs.window_size_right,
+                                                                      kargs.seqlen_q,
+                                                                      kargs.seqlen_k,
+                                                                      kargs.mask_type);
+               }
+               else
+               {
+                   return Alibi<SaccDataType, true>{
+                       slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
+               }
+           }
+           else
+           {
+               return EmptyPositionEncoding<SaccDataType>{};
+           }
+       }();

        auto o_acc_tile = [&]() {
            if constexpr(kDoFp8StaticQuant)
            {
...
@@ -672,6 +729,7 @@ struct FmhaFwdKernel
                    scales{kargs.scale_p},                               // p_compute_element_func
                    composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                    mask,
+                   position_encoding,
                    kargs.scale_s,
                    smem_ptr);
            }
...
@@ -683,6 +741,7 @@ struct FmhaFwdKernel
                    bias_dram_window,
                    lse_dram_window,
                    mask,
+                   position_encoding,
                    kargs.scale_s,
                    smem_ptr);
            }
...
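The kernel now selects its bias kernel-argument block from BiasEnum through nested std::conditional_t rather than a single kHasBias bool. A minimal standalone sketch of that selection pattern follows; the struct names here are illustrative stand-ins, not the kernel's real kargs types.

    #include <type_traits>

    enum class BlockAttentionBiasEnum { NO_BIAS, ELEMENTWISE_BIAS, ALIBI };

    struct ElementwiseBiasKargs { const void* bias_ptr; int stride_bias; };
    struct AlibiKargs           { const void* alibi_slope_ptr; int alibi_slope_stride; };
    struct EmptyKargs           {};

    // Pick the kargs base class at compile time from the bias kind.
    template <BlockAttentionBiasEnum BiasEnum>
    using BiasKargs =
        std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                           ElementwiseBiasKargs,
                           std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                              AlibiKargs,
                                              EmptyKargs>>;

    static_assert(std::is_same_v<BiasKargs<BlockAttentionBiasEnum::ALIBI>, AlibiKargs>);
    static_assert(std::is_same_v<BiasKargs<BlockAttentionBiasEnum::NO_BIAS>, EmptyKargs>);

    int main() { return 0; }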
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp  (+19 / -0)

...
@@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum
    QSKSVS,
};

+template <BlockFmhaPipelineEnum>
+struct BlockFmhaPipelineEnumToStr;
+
+template <>
+struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS>
+{
+    static constexpr const char* name = "qr";
+};
+template <>
+struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC>
+{
+    static constexpr const char* name = "qr_async";
+};
+template <>
+struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
+{
+    static constexpr const char* name = "qs";
+};
+
} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp  (+1 / -1)

...
@@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem
    static constexpr bool kPadSeqLenK       = Traits::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ      = Traits::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV      = Traits::kPadHeadDimV;
-   static constexpr bool kHasBias          = Traits::kHasBias;
+   static constexpr auto BiasEnum          = Traits::BiasEnum;
    static constexpr bool kStoreLSE         = Traits::kStoreLSE;
    static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
    static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp  (+40 / -11)

...
@@ -4,6 +4,7 @@
 #pragma once
 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"
...
@@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
-   static constexpr bool kHasBias     = Problem::kHasBias;
+   static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kStoreLSE    = Problem::kStoreLSE;
    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...
@@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS
        }
        else if constexpr(kK0BlockLength <= 128)
        {
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                return 1;
            else
                return 2;
...
@@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS
              typename LSEElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
-             typename OAccElementFunction>
+             typename OAccElementFunction,
+             typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const QElementFunction& q_element_func,
...
@@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS
               const PComputeElementFunction& p_compute_element_func,
               const OAccElementFunction& o_acc_element_func,
               FmhaMask mask,
+              PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr) const
    {
...
@@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS
                k_block_tile = load_tile(k_dram_window);
            }
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
            }
            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
...
@@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS
            }
            // STAGE 2, scale_s, add bias, mask, softmax
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
...
@@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS
                    s_acc,
                    bias_tile);
            }
+           else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+           {
+               const auto k_origin    = k_dram_block_window.get_window_origin();
+               constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+               s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
+               sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                   sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
+                       const auto tile_idx = get_x_indices_from_distributed_indices(
+                           s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                       const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                       const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                       constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                       s_acc(i_j_idx) *= scale_s;
+                       position_encoding.update(s_acc(i_j_idx), row, col);
+                   });
+               });
+           }
            else
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...
@@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS
            static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                /// NOTICE: bias might be materialized mask including -inf values, need
                /// consideration
-               if constexpr(kHasBias || FmhaMask::IsMasking)
+               if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                            FmhaMask::IsMasking)
                {
                    return raw_m == -numeric<SMPLComputeDataType>::infinity()
                               ? type_convert<SMPLComputeDataType>(0.f)
...
@@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                    }
...
@@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                const auto tmp = [&]() {
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                    }
...
@@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS
                sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                    constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
                    }
...
@@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
-             typename LSEDramBlockWindowTmp>
+             typename LSEDramBlockWindowTmp,
+             typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...
@@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
               FmhaMask mask,
+              PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr) const
    {
...
@@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS
                          identity{},
                          identity{},
                          mask,
+                         position_encoding,
                          scale_s,
                          smem_ptr);
    }
...
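For intuition, the ALIBI branch above sweeps the distributed S tile, scales each element by scale_s and then calls position_encoding.update(row, col). A plain scalar reference sketch of the same effect follows (assumptions: row-major S of size [seqlen_q, seqlen_k], seqlen_q <= seqlen_k, bottom-right-aligned mode, plain loops instead of tile spans); it is not the kernel's code, only the arithmetic it implements: S[r][c] = S[r][c] * scale_s - slope * |(r + shift) - c|.

    #include <algorithm>
    #include <cstdlib>
    #include <vector>

    void apply_alibi_bottom_right(std::vector<std::vector<float>>& S,
                                  float scale_s,
                                  float slope,
                                  int seqlen_q,
                                  int seqlen_k)
    {
        // align the zero-distance diagonal to the bottom-right corner
        const int shift = std::max(seqlen_k - seqlen_q, 0);
        for(int r = 0; r < seqlen_q; ++r)
            for(int c = 0; c < seqlen_k; ++c)
                S[r][c] = S[r][c] * scale_s - slope * std::abs((r + shift) - c);
    }

    int main()
    {
        std::vector<std::vector<float>> S(2, std::vector<float>(4, 0.0f));
        apply_alibi_bottom_right(S, 1.0f, 0.5f, 2, 4); // row 0: -1, -0.5, 0, -0.5 ...
        return 0;
    }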
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp  (+44 / -14)

...
@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"
...
@@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
    static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
-   static constexpr bool kHasBias     = Problem::kHasBias;
+   static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kStoreLSE    = Problem::kStoreLSE;
    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...
@@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync
        {
            if constexpr(kK0BlockLength <= 32)
            {
-               if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking)
+               if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
+                            FmhaMask::IsMasking)
                    return 1;
                else
                    return 2;
            }
            else if constexpr(kK0BlockLength <= 64)
            {
-               if constexpr(kPadSeqLenK && kHasBias)
+               if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 2;
                else
                    return 3;
            }
            else if constexpr(kK0BlockLength <= 128)
            {
-               if constexpr(kPadSeqLenK && kHasBias)
+               if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 1;
                else
                    return 2;
...
@@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync
              typename LSEElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
-             typename OAccElementFunction>
+             typename OAccElementFunction,
+             typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const QElementFunction& q_element_func,
...
@@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync
               const PComputeElementFunction& p_compute_element_func,
               const OAccElementFunction& o_acc_element_func,
               FmhaMask mask,
+              PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr) const
    {
...
@@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync
        const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

-       // check early exit if masked and no work to do.
-       if constexpr(FmhaMask::IsMasking)
+       // check early exit
+       if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
        {
            if(num_total_loop <= 0)
            {
...
@@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync
            __builtin_amdgcn_sched_barrier(1);
            // STAGE 2, scale_s, add bias, mask, softmax
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
...
@@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync
                    s_acc,
                    bias_tile);
            }
+           else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+           {
+               const auto k_origin    = k_dram_block_window.get_window_origin();
+               constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+               s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
+               sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                   sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
+                       const auto tile_idx = get_x_indices_from_distributed_indices(
+                           s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                       const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                       const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                       constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                       s_acc(i_j_idx) *= scale_s;
+                       position_encoding.update(s_acc(i_j_idx), row, col);
+                   });
+               });
+           }
            else
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...
@@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync
            static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                /// NOTICE: bias might be materialized mask including -inf values, need
-               /// consideration
-               if constexpr(kHasBias || FmhaMask::IsMasking)
+               /// consideration. alibi does not have this problem
+               if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                            FmhaMask::IsMasking)
                {
                    return raw_m == -numeric<SMPLComputeDataType>::infinity()
                               ? type_convert<SMPLComputeDataType>(0.f)
...
@@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                    }
...
@@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                const auto tmp = [&]() {
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                    }
...
@@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync
                sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                    constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
                    }
...
@@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
-             typename LSEDramBlockWindowTmp>
+             typename LSEDramBlockWindowTmp,
+             typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...
@@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
               FmhaMask mask,
+              PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr) const
    {
...
@@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync
                          identity{},
                          identity{},
                          mask,
+                         position_encoding,
                          scale_s,
                          smem_ptr);
    }
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp  (+13 / -9)

...
@@ -4,6 +4,7 @@
 #pragma once
 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"
...
@@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
-   static constexpr bool kHasBias     = Problem::kHasBias;
+   static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kStoreLSE    = Problem::kStoreLSE;
    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...
@@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
        }
        else if constexpr(kK0BlockLength <= 128)
        {
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                return 1;
            else
                return 2;
...
@@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
-             typename LSEDramBlockWindowTmp>
+             typename LSEDramBlockWindowTmp,
+             typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...
@@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/,           // not supported
               FmhaMask mask,
+              PositionEncoding /*position_encoding*/,
               float scale_s,
               float descale_qk,
               float descale_sv,
...
@@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
                k_block_tile = load_tile(k_dram_window);
            }
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
            }
            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
...
@@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
            }
            // STAGE 2, scale_s, add bias, mask, softmax
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
...
@@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
            static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                /// NOTICE: bias might be materialized mask including -inf values, need
                /// consideration
-               if constexpr(kHasBias || FmhaMask::IsMasking)
+               if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                            FmhaMask::IsMasking)
                {
                    return raw_m == -numeric<SMPLComputeDataType>::infinity()
                               ? type_convert<SMPLComputeDataType>(0.f)
...
@@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    {
                        p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                    }
...
@@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                const auto tmp = [&]() {
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    {
                        return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                    }
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp  (+40 / -11)

...
@@ -4,6 +4,7 @@
 #pragma once
 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"

 namespace ck_tile {
...
@@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
-   static constexpr bool kHasBias     = Problem::kHasBias;
+   static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kStoreLSE    = Problem::kStoreLSE;
    static constexpr index_t kBlockPerCu = []() {
...
@@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS
        }
        else if constexpr(kK0BlockLength <= 128)
        {
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                return 1;
            else
                return 2;
...
@@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS
              typename LSEElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
-             typename OAccElementFunction>
+             typename OAccElementFunction,
+             typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const QElementFunction& q_element_func,
...
@@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS
               const PComputeElementFunction& p_compute_element_func,
               const OAccElementFunction& o_acc_element_func,
               FmhaMask mask,
+              PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr) const
    {
...
@@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS
                k_block_tile = load_tile(k_dram_window);
            }
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
            }
            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
...
@@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS
            }
            // STAGE 2, scale_s, add bias, mask, softmax
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
...
@@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS
                    s_acc,
                    bias_tile);
            }
+           else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+           {
+               const auto k_origin    = k_dram_block_window.get_window_origin();
+               constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+               s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
+               sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                   sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
+                       const auto tile_idx = get_x_indices_from_distributed_indices(
+                           s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                       const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                       const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                       constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                       s_acc(i_j_idx) *= scale_s;
+                       position_encoding.update(s_acc(i_j_idx), row, col);
+                   });
+               });
+           }
            else
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...
@@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS
            static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                /// NOTICE: bias might be materialized mask including -inf values, need
                /// consideration
-               if constexpr(kHasBias || FmhaMask::IsMasking)
+               if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                            FmhaMask::IsMasking)
                {
                    return raw_m == -numeric<SMPLComputeDataType>::infinity()
                               ? type_convert<SMPLComputeDataType>(0.f)
...
@@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                    }
...
@@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                const auto tmp = [&]() {
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                    }
...
@@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS
                sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                    constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-                   if constexpr(kHasBias)
+                   if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
                    }
...
@@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
-             typename LSEDramBlockWindowTmp>
+             typename LSEDramBlockWindowTmp,
+             typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...
@@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
               FmhaMask mask,
+              PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr) const
    {
...
@@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS
                          identity{},
                          identity{},
                          mask,
+                         position_encoding,
                          scale_s,
                          smem_ptr);
    }
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp  (+3 / -2)

...
@@ -4,6 +4,7 @@
 #pragma once
 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

 namespace ck_tile {
...
@@ -11,7 +12,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
          bool kPadSeqLenK_ /* padding for seqlen_k */,
          bool kPadHeadDimQ_ /* padding for hdim_q */,
          bool kPadHeadDimV_ /* padding for hdim_v */,
-         bool kHasBias_,
+         BlockAttentionBiasEnum BiasEnum_,
          bool kStoreLSE_,
          bool kDoFp8StaticQuant_,
          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
...
@@ -21,7 +22,7 @@ struct TileFmhaTraits
    static constexpr bool kPadSeqLenK  = kPadSeqLenK_;
    static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
    static constexpr bool kPadHeadDimV = kPadHeadDimV_;
-   static constexpr bool kHasBias     = kHasBias_;
+   static constexpr auto BiasEnum     = BiasEnum_;
    static constexpr bool kStoreLSE    = kStoreLSE_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    static constexpr index_t kBlockPerCu    = kBlockPerCu_;
...
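A hedged instantiation sketch of the updated trait list: the boolean kHasBias_ slot is replaced by a BlockAttentionBiasEnum value, so an ALiBi configuration could look like the alias below. The concrete true/false values are placeholders chosen for illustration, not taken from this diff.

    #include "ck_tile/ops/fmha.hpp"

    // Hypothetical configuration: padded seqlen_q/seqlen_k, unpadded head dims, ALiBi bias,
    // LSE stored, no fp8 static quant, default occupancy.
    using FmhaTraitsAlibi = ck_tile::TileFmhaTraits<true,   // kPadSeqLenQ_
                                                    true,   // kPadSeqLenK_
                                                    false,  // kPadHeadDimQ_
                                                    false,  // kPadHeadDimV_
                                                    ck_tile::BlockAttentionBiasEnum::ALIBI, // BiasEnum_
                                                    true,   // kStoreLSE_
                                                    false,  // kDoFp8StaticQuant_
                                                    -1>;    // kBlockPerCu_ (keep default occupancy)

    static_assert(FmhaTraitsAlibi::BiasEnum == ck_tile::BlockAttentionBiasEnum::ALIBI);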
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  (+24 / -19)

...
@@ -27,14 +27,16 @@ using Empty_Tuple = ck::Tuple<>;
using BF16_Tuple      = ck::Tuple<BF16>;
using F16_Tuple       = ck::Tuple<F16>;
using F16_F16_Tuple   = ck::Tuple<F16, F16>;
using BF16_BF16_Tuple = ck::Tuple<BF16, BF16>;
using F64_Tuple       = ck::Tuple<F64>;
using F32_Tuple       = ck::Tuple<F32>;
using I32_Tuple       = ck::Tuple<I32>;
using I32_F32_Tuple   = ck::Tuple<I32, F32>;
using I8_Tuple        = ck::Tuple<I8>;
using F32_F32_Tuple   = ck::Tuple<F32, F32>;
...
@@ -91,23 +93,26 @@ using GK_Tuple = ck::Tuple<G_K>;
using GK_GK_Tuple = ck::Tuple<G_K, G_K>;

// pointwise functor
using PassThrough          = ck::tensor_operation::element_wise::PassThrough;
using Relu                 = ck::tensor_operation::element_wise::Relu;
using TanH                 = ck::tensor_operation::element_wise::TanH;
using Scale                = ck::tensor_operation::element_wise::Scale;
using Bilinear             = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu       = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu          = ck::tensor_operation::element_wise::AddFastGelu;
+using MultiplyAddFastGelu  = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AddRelu              = ck::tensor_operation::element_wise::AddRelu;
using AddSilu              = ck::tensor_operation::element_wise::AddSilu;
using AddReluAdd           = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu             = ck::tensor_operation::element_wise::FastGelu;
+using MultiplyFastGelu     = ck::tensor_operation::element_wise::MultiplyFastGelu;
using AddMultiply          = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd          = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd             = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu                 = ck::tensor_operation::element_wise::Gelu;
using Swish                = ck::tensor_operation::element_wise::Swish;
using Add                  = ck::tensor_operation::element_wise::Add;
+using Multiply             = ck::tensor_operation::element_wise::Multiply;

template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp
View file @
036c5234
...
...
@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
using
Scales
=
ck
::
tensor_operation
::
element_wise
::
Scales
;
using
Multiply
=
ck
::
tensor_operation
::
element_wise
::
Multiply
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AddFastGelu
=
ck
::
tensor_operation
::
element_wise
::
AddFastGelu
;
...
...
@@ -33,7 +33,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
ck
::
Tuple
<
BF16
>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
AddFastGelu
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
...
...
@@ -46,7 +46,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
ck
::
Tuple
<
BF16
>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
Add
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
...
...
@@ -59,7 +59,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
FastGelu
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances
(
...
...
@@ -72,7 +72,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
PassThrough
>>>&
instances
);
// RCR
...
...
@@ -86,7 +86,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
ck
::
Tuple
<
BF16
>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
AddFastGelu
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances
(
...
...
@@ -99,7 +99,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
ck
::
Tuple
<
BF16
>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
Add
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances
(
...
...
@@ -112,7 +112,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
FastGelu
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances
(
...
...
@@ -125,7 +125,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
PassThrough
>>>&
instances
);
// CRR
...
...
@@ -139,7 +139,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(
ck
::
Tuple
<
BF16
>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
AddFastGelu
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances
(
...
...
@@ -152,7 +152,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
ck
::
Tuple
<
BF16
>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
Add
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances
(
...
...
@@ -165,7 +165,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
FastGelu
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances
(
...
...
@@ -178,8 +178,62 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
Scales
,
Multiply
,
PassThrough
>>>&
instances
);
// Multiply
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
ck
::
Tuple
<
Row
>
,
ck
::
Tuple
<
Row
>
,
ck
::
Tuple
<
Row
,
Row
>
,
Row
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<
I8
>
,
ck
::
Tuple
<
BF16
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyAddFastGelu
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
ck
::
Tuple
<
Row
>
,
ck
::
Tuple
<
Row
>
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<
I8
>
,
ck
::
Tuple
<
BF16
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyAdd
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
ck
::
Tuple
<
Row
>
,
ck
::
Tuple
<
Row
>
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<
I8
>
,
ck
::
Tuple
<
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyFastGelu
>>>&
instances
);
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
ck
::
Tuple
<
Row
>
,
ck
::
Tuple
<
Row
>
,
ck
::
Tuple
<
Row
>
,
Row
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<
I8
>
,
ck
::
Tuple
<
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
Multiply
>>>&
instances
);
#endif
// GEMM + Add + Gelu
...
...
@@ -201,7 +255,7 @@ struct DeviceOperationInstanceFactory<
DsDataType
,
EDataType
,
PassThrough
,
Scales
,
Multiply
,
AddFastGelu
>>
{
using
DeviceOp
=
DeviceGemmMultipleABD
<
AsLayout
,
...
...
@@ -213,7 +267,7 @@ struct DeviceOperationInstanceFactory<
DsDataType
,
EDataType
,
PassThrough
,
Scales
,
Multiply
,
AddFastGelu
>
;
static
auto
GetInstances
()
...
...
@@ -271,7 +325,7 @@ struct DeviceOperationInstanceFactory<
DsDataType
,
EDataType
,
PassThrough
,
Scales
,
Multiply
,
Add
>>
{
using
DeviceOp
=
DeviceGemmMultipleABD
<
AsLayout
,
...
...
@@ -283,7 +337,7 @@ struct DeviceOperationInstanceFactory<
DsDataType
,
EDataType
,
PassThrough
,
Scales
,
Multiply
,
Add
>
;
static
auto
GetInstances
()
...
...
@@ -341,7 +395,7 @@ struct DeviceOperationInstanceFactory<
DsDataType
,
EDataType
,
PassThrough
,
Scales
,
Multiply
,
FastGelu
>>
{
using
DeviceOp
=
DeviceGemmMultipleABD
<
AsLayout
,
...
...
@@ -353,7 +407,7 @@ struct DeviceOperationInstanceFactory<
DsDataType
,
EDataType
,
PassThrough
,
Scales
,
Multiply
,
FastGelu
>
;
static
auto
GetInstances
()
...
...
@@ -411,7 +465,7 @@ struct DeviceOperationInstanceFactory<
DsDataType
,
EDataType
,
PassThrough
,
Scales
,
Multiply
,
PassThrough
>>
{
using
DeviceOp
=
DeviceGemmMultipleABD
<
AsLayout
,
...
...
@@ -423,7 +477,7 @@ struct DeviceOperationInstanceFactory<
DsDataType
,
EDataType
,
PassThrough
,
Scales
,
Multiply
,
PassThrough
>
;
static
auto
GetInstances
()
...
...
@@ -462,6 +516,234 @@ struct DeviceOperationInstanceFactory<
}
};
// Multiply
// GEMM + Add + Gelu
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AsDataType
,
typename
BsDataType
,
typename
DsDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleABD
<
AsLayout
,
BsLayout
,
DsLayout
,
ELayout
,
AsDataType
,
BsDataType
,
DsDataType
,
EDataType
,
PassThrough
,
PassThrough
,
MultiplyAddFastGelu
>>
{
using
DeviceOp
=
DeviceGemmMultipleABD
<
AsLayout
,
BsLayout
,
DsLayout
,
ELayout
,
AsDataType
,
BsDataType
,
DsDataType
,
EDataType
,
PassThrough
,
PassThrough
,
MultiplyAddFastGelu
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
AsDataType
,
ck
::
Tuple
<
BF16
>>
&&
is_same_v
<
BsDataType
,
ck
::
Tuple
<
I8
>>
&&
is_same_v
<
DsDataType
,
ck
::
Tuple
<
BF16
,
BF16
>>
&&
is_same_v
<
EDataType
,
BF16
>
)
{
if
constexpr
(
is_same_v
<
AsLayout
,
ck
::
Tuple
<
Row
>>
&&
is_same_v
<
BsLayout
,
ck
::
Tuple
<
Row
>>
&&
is_same_v
<
DsLayout
,
ck
::
Tuple
<
Row
,
Row
>>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
};
// GEMM + Add
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AsDataType
,
typename
BsDataType
,
typename
DsDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleABD
<
AsLayout
,
BsLayout
,
DsLayout
,
ELayout
,
AsDataType
,
BsDataType
,
DsDataType
,
EDataType
,
PassThrough
,
PassThrough
,
MultiplyAdd
>>
{
using
DeviceOp
=
DeviceGemmMultipleABD
<
AsLayout
,
BsLayout
,
DsLayout
,
ELayout
,
AsDataType
,
BsDataType
,
DsDataType
,
EDataType
,
PassThrough
,
PassThrough
,
MultiplyAdd
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
AsDataType
,
ck
::
Tuple
<
BF16
>>
&&
is_same_v
<
BsDataType
,
ck
::
Tuple
<
I8
>>
&&
is_same_v
<
DsDataType
,
ck
::
Tuple
<
BF16
,
BF16
>>
&&
is_same_v
<
EDataType
,
BF16
>
)
{
if
constexpr
(
is_same_v
<
AsLayout
,
ck
::
Tuple
<
Row
>>
&&
is_same_v
<
BsLayout
,
ck
::
Tuple
<
Row
>>
&&
is_same_v
<
DsLayout
,
ck
::
Tuple
<
Row
,
Row
>>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
};
// GEMM + Gelu
template <typename AsLayout, typename BsLayout, typename DsLayout, typename ELayout,
          typename AsDataType, typename BsDataType, typename DsDataType, typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleABD<
    AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType,
    PassThrough, PassThrough, MultiplyFastGelu>>
{
    using DeviceOp = DeviceGemmMultipleABD<AsLayout, BsLayout, DsLayout, ELayout,
                                           AsDataType, BsDataType, DsDataType, EDataType,
                                           PassThrough, PassThrough, MultiplyFastGelu>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_INT8
        if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> && is_same_v<BsDataType, ck::Tuple<I8>> &&
                     is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
        {
            if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> && is_same_v<BsLayout, ck::Tuple<Row>> &&
                         is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
                    op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
};
// GEMM
template <typename AsLayout, typename BsLayout, typename DsLayout, typename ELayout,
          typename AsDataType, typename BsDataType, typename DsDataType, typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleABD<
    AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType,
    PassThrough, PassThrough, Multiply>>
{
    using DeviceOp = DeviceGemmMultipleABD<AsLayout, BsLayout, DsLayout, ELayout,
                                           AsDataType, BsDataType, DsDataType, EDataType,
                                           PassThrough, PassThrough, Multiply>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_INT8
        if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> && is_same_v<BsDataType, ck::Tuple<I8>> &&
                     is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
        {
            if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> && is_same_v<BsLayout, ck::Tuple<Row>> &&
                         is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
...
...
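Illustrative usage sketch (not part of this commit): the factory specializations above are typically consumed by instantiating DeviceOperationInstanceFactory with a concrete operation type and iterating over GetInstances(). The snippet below is a minimal sketch under the assumption that the gemm_multi_abd.hpp umbrella header shown in this commit provides the factory and type aliases, and that the corresponding CK instance library is linked.

// sketch only: enumerate the bf16 x int8 multi-ABD GEMM instances added by this commit
#include <cstdint>
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp" // assumed header path

int main()
{
    using ck::tensor_operation::device::DeviceGemmMultipleABD;
    using ck::tensor_operation::element_wise::PassThrough;
    using ck::tensor_operation::element_wise::MultiplyAddFastGelu;
    using Row  = ck::tensor_layout::gemm::RowMajor;
    using BF16 = ck::bhalf_t;
    using I8   = std::int8_t;

    // Matches the "GEMM + Add + Gelu" factory specialization above (row-major, two extra D tensors).
    using DeviceOp = DeviceGemmMultipleABD<ck::Tuple<Row>, ck::Tuple<Row>, ck::Tuple<Row, Row>, Row,
                                           ck::Tuple<BF16>, ck::Tuple<I8>, ck::Tuple<BF16, BF16>, BF16,
                                           PassThrough, PassThrough, MultiplyAddFastGelu>;

    const auto op_ptrs =
        ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    // Each entry is one tuning instance of the kernel; an empty list usually means the
    // instance library was not linked or the data-type guard (CK_ENABLE_INT8) was off.
    std::cout << "multi_abd multiply+add+gelu instances: " << op_ptrs.size() << std::endl;
    return 0;
}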
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
...
...
@@ -315,6 +315,107 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instanc
        DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

template <typename ADataType,
          typename BDataType,
...
...
@@ -494,6 +595,64 @@ struct DeviceOperationInstanceFactory<
                op_ptrs);
        }
    }
#endif

#ifdef CK_ENABLE_FP16
    if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
                 is_same_v<CDataType, bhalf_t>)
    {
        if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<CLayout, Row>)
        {
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
        {
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
            add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(op_ptrs);
        }
    }
#endif

    return op_ptrs;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using namespace ck::tensor_layout::convolution;

using F16 = ck::half_t;
using F32 = float;

using Empty_Tuple = ck::Tuple<>;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdWeightDefault =
    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;

static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;

template <ck::index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename ELayout,
          ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple<
    // clang-format off
    //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
    //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
    //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
    DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 8, 1, 8>, 1>
    // clang-format on
    >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
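For orientation only, here is a hypothetical sketch of how the instance tuple defined above would typically be registered in a translation unit, using the add_device_operation_instances helper pulled in via add_device_operation_instance.hpp. The function name and base type match the declarations added later in this commit (grouped_convolution_backward_weight_xdl.inc); the concrete set of specializations registered by the commit's actual .cpp files may differ, and it is assumed that DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle implements the DeviceGroupedConvBwdWeight interface.

// Hypothetical registration sketch; not necessarily identical to the commit's instantiation file.
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, F16, F16, F16,
                                                           PassThrough, PassThrough, PassThrough>>>& instances)
{
    // Register the default-specialization tuple defined in the header above.
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<2, NHWGC, GKYXC, NHWGK,
                                                                             ConvBwdWeightDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck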
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp
...
...
@@ -86,6 +86,7 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_bilinear_instances = std:
    //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
    //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
    // generic instance
    DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>,
    DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>,
    // instance for small conv.K
    // for fp16 conv.K and conv.C must be divisible by 2
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...
...
@@ -352,6 +352,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
        {
            add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
            add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
                op_ptrs);
        }
#endif
#ifdef CK_ENABLE_BF16
...
...
@@ -419,6 +421,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
        {
            add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
            add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
                op_ptrs);
        }
#endif
#ifdef CK_ENABLE_BF16
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
...
...
@@ -113,6 +113,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
        PassThrough,
        PassThrough,
        PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, F16, F16, F16,
                                                           PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
...
...
@@ -192,6 +204,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
        PassThrough,
        PassThrough,
        PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, F16, F16, F16,
                                                           PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp
...
...
@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {

using Scales      = ck::tensor_operation::element_wise::Scales;
using Multiply    = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
...
...
@@ -32,7 +32,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_g
        ck::Tuple<BF16>, BF16, PassThrough, Scales, Multiply, AddFastGelu>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
...
...
@@ -45,7 +45,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_i
        ck::Tuple<BF16>, BF16, PassThrough, Scales, Multiply, Add>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
...
...
@@ -58,7 +58,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_i
        ck::Tuple<>, BF16, PassThrough, Scales, Multiply, FastGelu>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
...
...
@@ -71,7 +71,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instan
        ck::Tuple<>, BF16, PassThrough, Scales, Multiply, PassThrough>>>& instances);

// RCR
...
...
@@ -85,7 +85,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_g
        ck::Tuple<BF16>, BF16, PassThrough, Scales, Multiply, AddFastGelu>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
...
...
@@ -98,7 +98,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_i
        ck::Tuple<BF16>, BF16, PassThrough, Scales, Multiply, Add>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
...
...
@@ -111,7 +111,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_i
        ck::Tuple<>, BF16, PassThrough, Scales, Multiply, FastGelu>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
...
...
@@ -124,7 +124,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instan
        ck::Tuple<>, BF16, PassThrough, Scales, Multiply, PassThrough>>>& instances);

// CRR
...
...
@@ -138,7 +138,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_g
        ck::Tuple<BF16>, BF16, PassThrough, Scales, Multiply, AddFastGelu>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
...
...
@@ -151,7 +151,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_i
        ck::Tuple<BF16>, BF16, PassThrough, Scales, Multiply, Add>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
...
...
@@ -164,7 +164,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_i
        ck::Tuple<>, BF16, PassThrough, Scales, Multiply, FastGelu>>>& instances);

void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
...
...
@@ -177,7 +177,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan
        ck::Tuple<>, BF16, PassThrough, Scales, Multiply, PassThrough>>>& instances);

// GEMM + Add + Gelu
...
...
@@ -199,7 +199,7 @@ struct DeviceOperationInstanceFactory<
        DsDataType, EDataType, PassThrough, Scales, Multiply, AddFastGelu>>
{
    using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
...
...
@@ -211,7 +211,7 @@ struct DeviceOperationInstanceFactory<
        DsDataType, EDataType, PassThrough, Scales, Multiply, AddFastGelu>;

    static auto GetInstances()
...
...
@@ -270,7 +270,7 @@ struct DeviceOperationInstanceFactory<
        DsDataType, EDataType, PassThrough, Scales, Multiply, Add>>
{
    using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
...
...
@@ -282,7 +282,7 @@ struct DeviceOperationInstanceFactory<
        DsDataType, EDataType, PassThrough, Scales, Multiply, Add>;

    static auto GetInstances()
...
...
@@ -341,7 +341,7 @@ struct DeviceOperationInstanceFactory<
        DsDataType, EDataType, PassThrough, Scales, Multiply, FastGelu>>
{
    using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
...
...
@@ -353,7 +353,7 @@ struct DeviceOperationInstanceFactory<
        DsDataType, EDataType, PassThrough, Scales, Multiply, FastGelu>;

    static auto GetInstances()
...
...
@@ -412,7 +412,7 @@ struct DeviceOperationInstanceFactory<
        DsDataType, EDataType, PassThrough, Scales, Multiply, PassThrough>>
{
    using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
...
...
@@ -424,7 +424,7 @@ struct DeviceOperationInstanceFactory<
        DsDataType, EDataType, PassThrough, Scales, Multiply, PassThrough>;

    static auto GetInstances()
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP16
// fp16_output
void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row, Row, Empty_Tuple, Row,
                                                          F16, F16, Empty_Tuple, F16,
                                                          PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row, Col, Empty_Tuple, Row,
                                                          F16, F16, Empty_Tuple, F16,
                                                          PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif

template <typename ALayout, typename BLayout, typename ELayout,
          typename ADataType, typename BDataType, typename EDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout, BLayout, Empty_Tuple, ELayout,
                                                            ADataType, BDataType, Empty_Tuple, EDataType,
                                                            PassThrough, PassThrough, PassThrough>>
{
    using DeviceOp = DeviceGroupedGemmTileLoop<ALayout, BLayout, Empty_Tuple, ELayout,
                                               ADataType, BDataType, Empty_Tuple, EDataType,
                                               PassThrough, PassThrough, PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP16
        // fp16_output
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            }
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<ELayout, Row>)
            {
                add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
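Illustrative only (not part of the commit): client code can instantiate the factory above for a concrete layout/data-type combination and inspect the registered tile-loop kernels. The sketch below assumes the header path introduced by this commit, that the matching CK instance library is linked, and that GetTypeString() is available through the common CK BaseOperator interface.

// sketch only: list the fp16 grouped-GEMM tile-loop instances for the Row x Col case
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp"

int main()
{
    using ck::tensor_operation::device::DeviceGroupedGemmTileLoop;
    using ck::tensor_operation::element_wise::PassThrough;
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using Empty_Tuple = ck::Tuple<>;
    using F16         = ck::half_t;

    // Matches the Row x Col (mk_nk_mn) fp16 specialization handled by the factory above.
    using DeviceOp = DeviceGroupedGemmTileLoop<Row, Col, Empty_Tuple, Row,
                                               F16, F16, Empty_Tuple, F16,
                                               PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs =
        ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << std::endl; // name of each registered tuning instance
    return 0;
}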